Skip to content

Commit c6291d5

Browse files
authored
use streaming AudioDecoder to handle compressed formats (#1584)
1 parent eaf112c commit c6291d5

File tree

23 files changed

+346
-478
lines changed

23 files changed

+346
-478
lines changed

.changeset/big-cars-join.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
---
2+
"livekit-plugins-elevenlabs": minor
3+
"livekit-plugins-deepgram": minor
4+
"livekit-plugins-cartesia": patch
5+
"livekit-plugins-google": patch
6+
"livekit-plugins-openai": patch
7+
"livekit-plugins-playai": patch
8+
"livekit-plugins-rime": patch
9+
"livekit-plugins-aws": patch
10+
"livekit-agents": patch
11+
---
12+
13+
use streaming AudioDecoder to handle compressed encoding

livekit-agents/livekit/agents/tts/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
TTS,
1010
ChunkedStream,
1111
SynthesizedAudio,
12+
SynthesizedAudioEmitter,
1213
SynthesizeStream,
1314
TTSCapabilities,
1415
)
@@ -25,4 +26,5 @@
2526
"FallbackAdapter",
2627
"FallbackChunkedStream",
2728
"FallbackSynthesizeStream",
29+
"SynthesizedAudioEmitter",
2830
]

livekit-agents/livekit/agents/tts/tts.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,48 @@ async def __aexit__(
430430
exc_tb: TracebackType | None,
431431
) -> None:
432432
await self.aclose()
433+
434+
435+
class SynthesizedAudioEmitter:
436+
"""Utility for buffering and emitting audio frames with metadata to a channel.
437+
438+
This class helps TTS implementers to correctly handle is_final logic when streaming responses.
439+
"""
440+
441+
def __init__(
442+
self,
443+
*,
444+
event_ch: aio.Chan[SynthesizedAudio],
445+
request_id: str,
446+
segment_id: str = "",
447+
) -> None:
448+
self._event_ch = event_ch
449+
self._frame: rtc.AudioFrame | None = None
450+
self._request_id = request_id
451+
self._segment_id = segment_id
452+
453+
def push(self, frame: Optional[rtc.AudioFrame]):
454+
"""Emits any buffered frame and stores the new frame for later emission.
455+
456+
The buffered frame is emitted as not final.
457+
"""
458+
self._emit_frame(is_final=False)
459+
self._frame = frame
460+
461+
def _emit_frame(self, is_final: bool = False):
462+
"""Sends the buffered frame to the event channel if one exists."""
463+
if self._frame is None:
464+
return
465+
self._event_ch.send_nowait(
466+
SynthesizedAudio(
467+
frame=self._frame,
468+
request_id=self._request_id,
469+
segment_id=self._segment_id,
470+
is_final=is_final,
471+
)
472+
)
473+
self._frame = None
474+
475+
def flush(self):
476+
"""Emits any buffered frame as final."""
477+
self._emit_frame(is_final=True)

livekit-agents/livekit/agents/utils/codecs/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,5 @@
1313
# limitations under the License.
1414

1515
from .decoder import AudioStreamDecoder, StreamBuffer
16-
from .mp3 import Mp3StreamDecoder
1716

18-
__all__ = ["Mp3StreamDecoder", "AudioStreamDecoder", "StreamBuffer"]
17+
__all__ = ["AudioStreamDecoder", "StreamBuffer"]

livekit-agents/livekit/agents/utils/codecs/mp3.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

livekit-agents/livekit/agents/utils/connection_pool.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,20 @@ def __init__(
2424
self,
2525
*,
2626
max_session_duration: Optional[float] = None,
27+
mark_refreshed_on_get: bool = False,
2728
connect_cb: Optional[Callable[[], Awaitable[T]]] = None,
2829
close_cb: Optional[Callable[[T], Awaitable[None]]] = None,
2930
) -> None:
3031
"""Initialize the connection wrapper.
3132
3233
Args:
3334
max_session_duration: Maximum duration in seconds before forcing reconnection
35+
mark_refreshed_on_get: If True, the session will be marked as fresh when get() is called. Only used when max_session_duration is set.
3436
connect_cb: Optional async callback to create new connections
3537
close_cb: Optional async callback to close connections
3638
"""
3739
self._max_session_duration = max_session_duration
40+
self._mark_refreshed_on_get = mark_refreshed_on_get
3841
self._connect_cb = connect_cb
3942
self._close_cb = close_cb
4043
self._connections: dict[T, float] = {} # conn -> connected_at timestamp
@@ -95,6 +98,8 @@ async def get(self) -> T:
9598
self._max_session_duration is None
9699
or now - self._connections[conn] <= self._max_session_duration
97100
):
101+
if self._mark_refreshed_on_get:
102+
self._connections[conn] = now
98103
return conn
99104
# connection expired; mark it for resetting.
100105
self.remove(conn)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
import logging
22

33
logger = logging.getLogger("livekit.plugins.aws")
4+
for logger_name in ["botocore", "aiobotocore"]:
5+
logging.getLogger(logger_name).setLevel(logging.INFO)

livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@
4545
"de-CH",
4646
]
4747

48-
TTS_OUTPUT_FORMAT = Literal["pcm", "mp3"]
48+
TTS_OUTPUT_FORMAT = Literal["mp3"]

livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/tts.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import aiohttp
2020
from aiobotocore.session import AioSession, get_session
21-
from livekit import rtc
2221
from livekit.agents import (
2322
APIConnectionError,
2423
APIConnectOptions,
@@ -29,10 +28,9 @@
2928
)
3029

3130
from ._utils import _get_aws_credentials
32-
from .models import TTS_LANGUAGE, TTS_OUTPUT_FORMAT, TTS_SPEECH_ENGINE
31+
from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
3332

3433
TTS_NUM_CHANNELS: int = 1
35-
DEFAULT_OUTPUT_FORMAT: TTS_OUTPUT_FORMAT = "pcm"
3634
DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
3735
DEFAULT_SPEECH_REGION = "us-east-1"
3836
DEFAULT_VOICE = "Ruth"
@@ -43,7 +41,6 @@
4341
class _TTSOptions:
4442
# https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
4543
voice: str | None
46-
output_format: TTS_OUTPUT_FORMAT
4744
speech_engine: TTS_SPEECH_ENGINE
4845
speech_region: str
4946
sample_rate: int
@@ -56,7 +53,6 @@ def __init__(
5653
*,
5754
voice: str | None = DEFAULT_VOICE,
5855
language: TTS_LANGUAGE | str | None = None,
59-
output_format: TTS_OUTPUT_FORMAT = DEFAULT_OUTPUT_FORMAT,
6056
speech_engine: TTS_SPEECH_ENGINE = DEFAULT_SPEECH_ENGINE,
6157
sample_rate: int = DEFAULT_SAMPLE_RATE,
6258
speech_region: str = DEFAULT_SPEECH_REGION,
@@ -75,7 +71,6 @@ def __init__(
7571
Args:
7672
voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
7773
language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
78-
output_format(TTS_OUTPUT_FORMAT, optional): The format in which the returned output will be encoded. Defaults to "pcm".
7974
sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
8075
speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
8176
speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
@@ -96,7 +91,6 @@ def __init__(
9691

9792
self._opts = _TTSOptions(
9893
voice=voice,
99-
output_format=output_format,
10094
speech_engine=speech_engine,
10195
speech_region=speech_region,
10296
language=language,
@@ -149,7 +143,7 @@ async def _run(self):
149143
async with self._get_client() as client:
150144
params = {
151145
"Text": self._input_text,
152-
"OutputFormat": self._opts.output_format,
146+
"OutputFormat": "mp3",
153147
"Engine": self._opts.speech_engine,
154148
"VoiceId": self._opts.voice,
155149
"TextType": "text",
@@ -158,32 +152,36 @@ async def _run(self):
158152
}
159153
response = await client.synthesize_speech(**_strip_nones(params))
160154
if "AudioStream" in response:
161-
decoder = utils.codecs.Mp3StreamDecoder()
162-
async with response["AudioStream"] as resp:
163-
async for data, _ in resp.content.iter_chunks():
164-
if self._opts.output_format == "mp3":
165-
frames = decoder.decode_chunk(data)
166-
for frame in frames:
167-
self._event_ch.send_nowait(
168-
tts.SynthesizedAudio(
169-
request_id=request_id,
170-
segment_id=self._segment_id,
171-
frame=frame,
172-
)
173-
)
174-
else:
175-
self._event_ch.send_nowait(
176-
tts.SynthesizedAudio(
177-
request_id=request_id,
178-
segment_id=self._segment_id,
179-
frame=rtc.AudioFrame(
180-
data=data,
181-
sample_rate=self._opts.sample_rate,
182-
num_channels=1,
183-
samples_per_channel=len(data) // 2,
184-
),
185-
)
186-
)
155+
decoder = utils.codecs.AudioStreamDecoder(
156+
sample_rate=self._opts.sample_rate,
157+
num_channels=1,
158+
)
159+
160+
# Create a task to push data to the decoder
161+
async def push_data():
162+
try:
163+
async with response["AudioStream"] as resp:
164+
async for data, _ in resp.content.iter_chunks():
165+
decoder.push(data)
166+
finally:
167+
decoder.end_input()
168+
169+
# Start pushing data to the decoder
170+
push_task = asyncio.create_task(push_data())
171+
172+
try:
173+
# Create emitter and process decoded frames
174+
emitter = tts.SynthesizedAudioEmitter(
175+
event_ch=self._event_ch,
176+
request_id=request_id,
177+
segment_id=self._segment_id,
178+
)
179+
async for frame in decoder:
180+
emitter.push(frame)
181+
emitter.flush()
182+
await push_task
183+
finally:
184+
await utils.aio.gracefully_cancel(push_task)
187185

188186
except asyncio.TimeoutError as e:
189187
raise APITimeoutError() from e

0 commit comments

Comments
 (0)