Skip to content

Commit a844b3c

Browse files
committed
initial commit
1 parent e9e87e6 commit a844b3c

File tree

1 file changed

+206
-0
lines changed

1 file changed

+206
-0
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from typing import AsyncIterator, Optional, Set, Dict
2+
import asyncio
3+
import numpy as np
4+
import contextlib
5+
from dataclasses import dataclass
6+
from .audio_frame import AudioFrame
7+
from .log import logger
8+
9+
10+
# Alias for the kind of input the mixer consumes: an async iterator that yields AudioFrame objects.
_Stream = AsyncIterator[AudioFrame]
11+
12+
13+
@dataclass
class _Contribution:
    """What a single input stream hands to the mixer for one mixing round.

    Instances are built positionally by ``AudioMixer._get_contribution``, so the
    field order below is part of the contract.
    """

    # The input stream this contribution was pulled from.
    stream: _Stream
    # Exactly one chunk of samples to be summed into the mixed output frame.
    data: np.ndarray
    # Samples read from the stream that did not fit into ``data``; carried over to the next round.
    buffer: np.ndarray
    # True when real samples (leftover or freshly read) were available, i.e. ``data`` is not pure padding.
    had_data: bool
    # True when the stream raised StopAsyncIteration while being read.
    exhausted: bool
20+
21+
22+
class AudioMixer:
    """Mix several async audio streams into a single stream of fixed-size frames.

    Input streams are registered with :meth:`add_stream`. A background task pulls
    one chunk of int16 samples from every stream, sums them, clips the result to
    the int16 range and pushes the resulting ``AudioFrame`` onto a bounded queue.
    The mixer itself is an async iterator over those mixed frames.
    """

    def __init__(
        self,
        *,
        sample_rate: int,
        num_channels: int,
        blocksize: int = 0,
        stream_timeout_ms: int = 100,
        capacity: int = 100,
    ) -> None:
        """
        Initialize the AudioMixer.

        The mixer accepts multiple async audio streams and mixes them into a single output stream.
        Each output frame is generated with a fixed chunk size determined by the blocksize (in samples).
        If blocksize is not provided (or 0), it defaults to 100ms.

        Each input stream is processed in parallel, accumulating audio data until at least one chunk
        of samples is available. If an input stream does not provide data within the specified timeout,
        a warning is logged. The mixer can be closed immediately
        (dropping unconsumed frames) or allowed to flush remaining data using end_input().

        The mixing task is started with ``asyncio.create_task``, so the mixer must be
        constructed while an event loop is running.

        Args:
            sample_rate (int): The audio sample rate in Hz.
            num_channels (int): The number of audio channels.
            blocksize (int, optional): The size of the audio block (in samples) for mixing. If not provided,
                defaults to sample_rate // 10.
            stream_timeout_ms (int, optional): The maximum wait time in milliseconds for each stream to provide
                audio data before timing out. Defaults to 100 ms.
            capacity (int, optional): The maximum number of mixed frames to store in the output queue.
                Defaults to 100.
        """
        self._streams: set[_Stream] = set()
        self._buffers: dict[_Stream, np.ndarray] = {}
        self._sample_rate: int = sample_rate
        self._num_channels: int = num_channels
        self._chunk_size: int = blocksize if blocksize > 0 else int(sample_rate // 10)
        self._stream_timeout_ms: int = stream_timeout_ms
        self._queue: asyncio.Queue[Optional[AudioFrame]] = asyncio.Queue(maxsize=capacity)
        # _closed: hard stop requested via close(); the mixing loop ends immediately.
        self._closed: bool = False
        # _ending: graceful stop requested via end_input(); the loop ends once all
        # registered streams have been drained and removed.
        self._ending: bool = False
        self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer())

    def _empty_buffer(self) -> np.ndarray:
        """Return a zero-length int16 sample buffer with one column per channel."""
        return np.empty((0, self._num_channels), dtype=np.int16)

    def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
        """
        Add an audio stream to the mixer.

        The stream is added to the internal set of streams and an empty buffer is initialized for it,
        if not already present.

        Args:
            stream (AsyncIterator[AudioFrame]): An async iterator that produces AudioFrame objects.
        """
        self._streams.add(stream)
        if stream not in self._buffers:
            self._buffers[stream] = self._empty_buffer()

    def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
        """
        Remove an audio stream from the mixer.

        This method removes the specified stream and its associated buffer from the mixer.
        Removing a stream that is not registered is a no-op.

        Args:
            stream (AsyncIterator[AudioFrame]): The audio stream to remove.
        """
        self._streams.discard(stream)
        self._buffers.pop(stream, None)

    # NOTE: __aiter__ must be a plain (non-async) method. Declared as `async def`
    # it would return a coroutine, and `async for` would fail with TypeError.
    def __aiter__(self) -> "AudioMixer":
        """
        Return the async iterator interface for the mixer.

        Returns:
            AudioMixer: The mixer itself, as it implements the async iterator protocol.
        """
        return self

    async def __anext__(self) -> AudioFrame:
        """
        Retrieve the next mixed AudioFrame from the output queue.

        Returns:
            AudioFrame: The next mixed audio frame.

        Raises:
            StopAsyncIteration: When the mixer is closed and no more frames are available.
        """
        # After close() nothing will ever be queued again, so waiting on an
        # empty queue would block forever.
        if self._closed and self._queue.empty():
            raise StopAsyncIteration
        item = await self._queue.get()
        if item is None:
            raise StopAsyncIteration
        return item

    async def close(self) -> None:
        """
        Immediately stop mixing and close the mixer.

        This cancels the mixing task, and any unconsumed output in the queue may be dropped.
        Consumers iterating over the mixer are released: frames already queued can still be
        read, after which iteration stops.
        """
        self._closed = True
        self._mixer_task.cancel()
        try:
            with contextlib.suppress(asyncio.CancelledError):
                await self._mixer_task
        finally:
            # The cancelled task never reaches its own sentinel put, so wake any
            # consumer blocked in __anext__. If the queue is full no consumer is
            # blocked on an empty queue, and __anext__ stops by itself once the
            # queue has been drained.
            with contextlib.suppress(asyncio.QueueFull):
                self._queue.put_nowait(None)

    def end_input(self) -> None:
        """
        Signal that no more streams will be added.

        The mixer keeps processing the streams that are already registered until each of
        them is exhausted, flushes the remaining buffered audio, and then ends iteration.
        """
        # Deliberately does not set _closed: that flag stops the mixing loop at once,
        # which would drop the audio this method promises to flush.
        self._ending = True

    async def _mixer(self) -> None:
        """Background task: repeatedly mix one chunk from every stream into the output queue.

        Runs until close() is called, or until end_input() has been called and every
        stream has been removed. On a normal exit a ``None`` sentinel is queued so that
        iteration over the mixer terminates.
        """
        while not self._closed:
            if not self._streams:
                if self._ending:
                    # Graceful shutdown: everything has been drained.
                    break
                await asyncio.sleep(0.01)
                continue

            # Snapshot the set: streams may be added/removed while we await below.
            streams = list(self._streams)
            results = await asyncio.gather(
                *(
                    self._get_contribution(
                        stream, self._buffers.get(stream, self._empty_buffer())
                    )
                    for stream in streams
                ),
                return_exceptions=True,
            )

            contributions = []
            any_data = False
            removals = []
            # gather() preserves order, so results line up with the snapshot.
            for stream, result in zip(streams, results):
                if not isinstance(result, _Contribution):
                    # The stream raised something other than StopAsyncIteration.
                    # Drop it instead of silently re-polling a broken stream forever.
                    logger.warning(
                        f"AudioMixer: stream {stream} failed, removing: {result!r}"
                    )
                    removals.append(stream)
                    continue

                # Sum in float32 so that adding several int16 streams cannot wrap around.
                contributions.append(result.data.astype(np.float32))
                # Only keep leftovers for streams that are still registered; a stream
                # removed during this round must not get its buffer entry resurrected.
                if stream in self._streams:
                    self._buffers[stream] = result.buffer
                if result.had_data:
                    any_data = True
                if result.exhausted and result.buffer.shape[0] == 0:
                    removals.append(stream)

            for stream in removals:
                self.remove_stream(stream)

            if not any_data:
                # Every contribution was pure padding: emit nothing for this round.
                await asyncio.sleep(0.001)
                continue

            mixed = np.sum(np.stack(contributions, axis=0), axis=0)
            mixed = np.clip(mixed, -32768, 32767).astype(np.int16)
            frame = AudioFrame(
                mixed.tobytes(), self._sample_rate, self._num_channels, self._chunk_size
            )
            await self._queue.put(frame)

        # Sentinel telling __anext__ that the stream of mixed frames is over.
        await self._queue.put(None)

    async def _get_contribution(
        self, stream: AsyncIterator[AudioFrame], buf: np.ndarray
    ) -> _Contribution:
        """Read from ``stream`` until one chunk of samples is available.

        Args:
            stream: The input stream to read from.
            buf: Samples left over from the previous round for this stream.

        Returns:
            _Contribution: ``data`` always holds exactly ``chunk_size`` rows. If the stream
            timed out or ended before a full chunk was collected, the available samples are
            zero-padded to a full chunk and the carried-over buffer is emptied.
        """
        had_data = buf.shape[0] > 0
        exhausted = False
        while buf.shape[0] < self._chunk_size and not exhausted:
            try:
                # NOTE(review): on timeout wait_for() cancels the pending __anext__()
                # call; whether the stream can be read again afterwards depends on the
                # stream implementation - confirm for async-generator inputs.
                frame = await asyncio.wait_for(
                    stream.__anext__(), timeout=self._stream_timeout_ms / 1000
                )
            except asyncio.TimeoutError:
                logger.warning(f"AudioMixer: stream {stream} timeout, ignoring")
                break
            except StopAsyncIteration:
                exhausted = True
                break
            # Interpret the frame payload as interleaved int16 samples, one row per sample index.
            new_data = np.frombuffer(frame.data.tobytes(), dtype=np.int16).reshape(
                -1, self._num_channels
            )
            buf = np.concatenate((buf, new_data), axis=0) if buf.size else new_data
            had_data = True

        if buf.shape[0] >= self._chunk_size:
            contrib, buf = buf[: self._chunk_size], buf[self._chunk_size :]
        else:
            pad = np.zeros(
                (self._chunk_size - buf.shape[0], self._num_channels), dtype=np.int16
            )
            contrib, buf = np.concatenate((buf, pad), axis=0), self._empty_buffer()
        return _Contribution(stream, contrib, buf, had_data, exhausted)

0 commit comments

Comments
 (0)