Skip to content

Commit 5befa5f

Browse files
committed
add lock for AudioMixer to sync access to stream
1 parent a832411 commit 5befa5f

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

livekit-rtc/livekit/rtc/audio_mixer.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
capacity (int, optional): The maximum number of mixed frames to store in the output queue.
5151
Defaults to 100.
5252
"""
53-
self._streams: set[_Stream] = set()
53+
self._streams: dict[_Stream, asyncio.Lock] = {}
5454
self._buffers: dict[_Stream, np.ndarray] = {}
5555
self._sample_rate: int = sample_rate
5656
self._num_channels: int = num_channels
@@ -62,7 +62,7 @@ def __init__(
6262
self._ending: bool = False
6363
self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer())
6464

65-
def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
65+
def add_stream(self, stream: AsyncIterator[AudioFrame]) -> asyncio.Lock:
6666
"""
6767
Add an audio stream to the mixer.
6868
@@ -71,13 +71,17 @@ def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
7171
7272
Args:
7373
stream (AsyncIterator[AudioFrame]): An async iterator that produces AudioFrame objects.
74+
75+
Returns:
76+
asyncio.Lock: A lock that can be used to synchronize access to the stream.
7477
"""
7578
if self._ending:
7679
raise RuntimeError("Cannot add stream after mixer has been closed")
7780

78-
self._streams.add(stream)
81+
self._streams[stream] = asyncio.Lock()
7982
if stream not in self._buffers:
8083
self._buffers[stream] = np.empty((0, self._num_channels), dtype=np.int16)
84+
return self._streams[stream]
8185

8286
def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
8387
"""
@@ -88,7 +92,7 @@ def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
8892
Args:
8993
stream (AsyncIterator[AudioFrame]): The audio stream to remove.
9094
"""
91-
self._streams.discard(stream)
95+
self._streams.pop(stream, None)
9296
self._buffers.pop(stream, None)
9397

9498
def __aiter__(self) -> "AudioMixer":
@@ -133,9 +137,10 @@ async def _mixer(self) -> None:
133137
tasks = [
134138
self._get_contribution(
135139
stream,
140+
lock,
136141
self._buffers.get(stream, np.empty((0, self._num_channels), dtype=np.int16)),
137142
)
138-
for stream in list(self._streams)
143+
for stream, lock in self._streams.items()
139144
]
140145
results = await asyncio.gather(*tasks, return_exceptions=True)
141146
contributions = []
@@ -169,15 +174,18 @@ async def _mixer(self) -> None:
169174
await self._queue.put(None)
170175

171176
async def _get_contribution(
172-
self, stream: AsyncIterator[AudioFrame], buf: np.ndarray
177+
self, stream: AsyncIterator[AudioFrame], lock: asyncio.Lock, buf: np.ndarray
173178
) -> _Contribution:
174179
had_data = buf.shape[0] > 0
175180
exhausted = False
181+
182+
async def _get_frame() -> AudioFrame:
183+
async with lock:
184+
return await stream.__anext__()
185+
176186
while buf.shape[0] < self._chunk_size and not exhausted:
177187
try:
178-
frame = await asyncio.wait_for(
179-
stream.__anext__(), timeout=self._stream_timeout_ms / 1000
180-
)
188+
frame = await asyncio.wait_for(_get_frame(), timeout=self._stream_timeout_ms / 1000)
181189
except asyncio.TimeoutError:
182190
logger.warning(f"AudioMixer: stream {stream} timeout, ignoring")
183191
break

0 commit comments

Comments
 (0)