Skip to content

Commit 5324148

Browse files
authored
add rtc.AudioMixer (#400)
1 parent e9e87e6 commit 5324148

File tree

6 files changed

+297
-2
lines changed

6 files changed

+297
-2
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ jobs:
2323
- name: Run tests
2424
run: |
2525
python3 ./livekit-rtc/rust-sdks/download_ffi.py --output livekit-rtc/livekit/rtc/resources
26-
pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc pydantic numpy
26+
pip3 install ./livekit-protocol ./livekit-api ./livekit-rtc
27+
pip3 install -r dev-requirements.txt
2728
pytest . --ignore=livekit-rtc/rust-sdks

dev-requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,8 @@ auditwheel; sys_platform == 'linux'
99
cibuildwheel
1010

1111
pytest
12+
pytest-asyncio
13+
14+
matplotlib
15+
pydantic
16+
numpy

livekit-rtc/livekit/rtc/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from .video_source import VideoSource
7777
from .video_stream import VideoFrameEvent, VideoStream
7878
from .audio_resampler import AudioResampler, AudioResamplerQuality
79+
from .audio_mixer import AudioMixer
7980
from .apm import AudioProcessingModule
8081
from .utils import combine_audio_frames
8182
from .rpc import RpcError, RpcInvocationData
@@ -148,6 +149,7 @@
148149
"VideoFrameEvent",
149150
"VideoSource",
150151
"VideoStream",
152+
"AudioMixer",
151153
"AudioResampler",
152154
"AudioResamplerQuality",
153155
"RpcError",
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import asyncio
2+
import numpy as np
3+
import contextlib
4+
from dataclasses import dataclass
5+
from typing import AsyncIterator, Optional
6+
from .audio_frame import AudioFrame
7+
from .log import logger
8+
9+
_Stream = AsyncIterator[AudioFrame]
10+
11+
12+
@dataclass
class _Contribution:
    """Result of pulling one chunk from a single input stream during a mixing round."""

    # The input stream this chunk was read from.
    stream: _Stream
    # The samples handed to the mixer for this round.
    data: np.ndarray
    # Samples left over after the chunk was cut; stored for the next round.
    buffer: np.ndarray
    # Whether the stream had buffered or newly received samples this round.
    had_data: bool
    # Whether the stream signalled its end while being read this round.
    exhausted: bool
19+
20+
21+
class AudioMixer:
    """
    Mix several asynchronous audio streams into a single stream of fixed-size frames.

    The mixer is itself an async iterator: iterate over it to receive the mixed
    AudioFrame objects.
    """

    def __init__(
        self,
        sample_rate: int,
        num_channels: int,
        *,
        blocksize: int = 0,
        stream_timeout_ms: int = 100,
        capacity: int = 100,
    ) -> None:
        """
        Initialize the AudioMixer.

        The mixer accepts multiple async audio streams and mixes them into a single output stream.
        Each output frame is generated with a fixed chunk size determined by the blocksize (in samples).
        If blocksize is not provided (or 0), it defaults to 100ms.

        Each input stream is processed in parallel, accumulating audio data until at least one chunk
        of samples is available. If an input stream does not provide data within the specified timeout,
        a warning is logged. The mixer can be closed immediately
        (dropping unconsumed frames) or allowed to flush remaining data using end_input().

        The mixing task is started here with asyncio.create_task(), so the mixer must be
        constructed while an event loop is running.

        Args:
            sample_rate (int): The audio sample rate in Hz.
            num_channels (int): The number of audio channels.
            blocksize (int, optional): The size of the audio block (in samples) for mixing. If not provided,
                defaults to sample_rate // 10.
            stream_timeout_ms (int, optional): The maximum wait time in milliseconds for each stream to provide
                audio data before timing out. Defaults to 100 ms.
            capacity (int, optional): The maximum number of mixed frames to store in the output queue.
                Defaults to 100.
        """
        self._streams: set[_Stream] = set()
        self._buffers: dict[_Stream, np.ndarray] = {}
        self._sample_rate: int = sample_rate
        self._num_channels: int = num_channels
        self._chunk_size: int = blocksize if blocksize > 0 else int(sample_rate // 10)
        self._stream_timeout_ms: int = stream_timeout_ms
        self._queue: asyncio.Queue[Optional[AudioFrame]] = asyncio.Queue(maxsize=capacity)
        # _ending signals that no new streams will be added,
        # but we continue processing until all streams are exhausted.
        self._ending: bool = False
        self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer())

    def add_stream(self, stream: "AsyncIterator[AudioFrame]") -> None:
        """
        Add an audio stream to the mixer.

        The stream is added to the internal set of streams and an empty buffer is initialized for it,
        if not already present.

        Args:
            stream (AsyncIterator[AudioFrame]): An async iterator that produces AudioFrame objects.

        Raises:
            RuntimeError: If end_input() or aclose() has already been called.
        """
        if self._ending:
            raise RuntimeError("Cannot add stream after mixer has been closed")

        self._streams.add(stream)
        if stream not in self._buffers:
            self._buffers[stream] = self._empty_buffer()

    def remove_stream(self, stream: "AsyncIterator[AudioFrame]") -> None:
        """
        Remove an audio stream from the mixer.

        This method removes the specified stream and its associated buffer from the mixer.
        Removing a stream that is not registered is a no-op.

        Args:
            stream (AsyncIterator[AudioFrame]): The audio stream to remove.
        """
        self._streams.discard(stream)
        self._buffers.pop(stream, None)

    def __aiter__(self) -> "AudioMixer":
        return self

    async def __anext__(self) -> "AudioFrame":
        item = await self._queue.get()
        # None is the end-of-stream sentinel.
        if item is None:
            raise StopAsyncIteration
        return item

    async def aclose(self) -> None:
        """
        Immediately stop mixing and close the mixer.

        This cancels the mixing task, and any unconsumed output in the queue may be dropped.
        Iteration over the mixer terminates once the already queued frames are consumed.
        """
        self._ending = True
        task = self._mixer_task
        task.cancel()
        try:
            with contextlib.suppress(asyncio.CancelledError):
                await task
        finally:
            # _mixer() only enqueues the end sentinel when it exits normally. If it was
            # cancelled or crashed, enqueue it here so that a consumer waiting in
            # __anext__() is woken up instead of blocking forever.
            if task.done() and (task.cancelled() or task.exception() is not None):
                self._enqueue_end_sentinel()

    def end_input(self) -> None:
        """
        Signal that no more streams will be added.

        This method marks the mixer as closed so that it flushes any remaining buffered output before ending.
        Note that existing streams will still be processed until exhausted.
        """
        self._ending = True

    def _empty_buffer(self) -> np.ndarray:
        # Zero-length int16 sample buffer with one column per channel.
        return np.empty((0, self._num_channels), dtype=np.int16)

    def _enqueue_end_sentinel(self) -> None:
        # Put the end-of-stream marker on the output queue without blocking.
        if self._queue.full():
            # Make room for the sentinel by dropping the oldest pending frame.
            with contextlib.suppress(asyncio.QueueEmpty):
                self._queue.get_nowait()
        with contextlib.suppress(asyncio.QueueFull):
            self._queue.put_nowait(None)

    async def _mixer(self) -> None:
        # Background task: repeatedly pull one chunk from every stream, sum the chunks and
        # enqueue the mixed frame. Ends with a None sentinel once end_input() was called
        # and no streams remain.
        while True:
            # If we're in ending mode and there are no more streams, exit.
            if self._ending and not self._streams:
                break

            if not self._streams:
                await asyncio.sleep(0.01)
                continue

            # Snapshot the streams so that each result can be matched to its stream,
            # including results that are exceptions.
            streams = list(self._streams)
            results = await asyncio.gather(
                *(
                    self._get_contribution(
                        stream, self._buffers.get(stream, self._empty_buffer())
                    )
                    for stream in streams
                ),
                return_exceptions=True,
            )
            contributions = []
            any_data = False
            removals = []
            for stream, result in zip(streams, results):
                if isinstance(result, Exception):
                    # A failing stream is reported and dropped rather than being silently
                    # polled again on every round.
                    logger.error(f"AudioMixer: stream {stream} failed with {result!r}, removing it")
                    removals.append(stream)
                    continue
                if not isinstance(result, _Contribution):
                    continue

                contributions.append(result.data.astype(np.float32))
                # Do not resurrect the buffer of a stream that remove_stream() dropped
                # while this round was in flight.
                if stream in self._streams:
                    self._buffers[stream] = result.buffer
                if result.had_data:
                    any_data = True
                if result.exhausted and result.buffer.shape[0] == 0:
                    removals.append(stream)

            for stream in removals:
                self.remove_stream(stream)

            if not any_data:
                await asyncio.sleep(0.001)
                continue

            # Sum in float32 and clip back into the int16 range to avoid wrap-around.
            mixed = np.sum(np.stack(contributions, axis=0), axis=0)
            mixed = np.clip(mixed, -32768, 32767).astype(np.int16)
            frame = AudioFrame(
                mixed.tobytes(), self._sample_rate, self._num_channels, self._chunk_size
            )
            await self._queue.put(frame)

        await self._queue.put(None)

    async def _get_contribution(
        self, stream: "AsyncIterator[AudioFrame]", buf: np.ndarray
    ) -> "_Contribution":
        # Read frames from `stream` until `buf` holds one chunk, the stream times out, or
        # the stream ends. Returns exactly one chunk (zero-padded when short) together
        # with the samples left over for the next round.
        had_data = buf.shape[0] > 0
        exhausted = False
        while buf.shape[0] < self._chunk_size and not exhausted:
            try:
                # NOTE(review): wait_for() cancels the pending __anext__() on timeout;
                # whether a stream can be polled again afterwards depends on the
                # iterator implementation - TODO confirm for the stream types in use.
                frame = await asyncio.wait_for(
                    stream.__anext__(), timeout=self._stream_timeout_ms / 1000
                )
            except asyncio.TimeoutError:
                logger.warning(f"AudioMixer: stream {stream} timeout, ignoring")
                break
            except StopAsyncIteration:
                exhausted = True
                break
            new_data = np.frombuffer(frame.data.tobytes(), dtype=np.int16).reshape(
                -1, self._num_channels
            )
            buf = np.concatenate((buf, new_data), axis=0) if buf.size else new_data
            had_data = True

        if buf.shape[0] >= self._chunk_size:
            contrib, buf = buf[: self._chunk_size], buf[self._chunk_size :]
        else:
            # Not enough samples: pad with silence and leave nothing for the next round.
            pad = np.zeros((self._chunk_size - buf.shape[0], self._num_channels), dtype=np.int16)
            contrib, buf = np.concatenate((buf, pad), axis=0), self._empty_buffer()
        return _Contribution(stream, contrib, buf, had_data, exhausted)

livekit-rtc/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def finalize_options(self):
5858
license="Apache-2.0",
5959
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
6060
python_requires=">=3.9.0",
61-
install_requires=["protobuf>=4.25.0", "types-protobuf>=3", "aiofiles>=24"],
61+
install_requires=["protobuf>=4.25.0", "types-protobuf>=3", "aiofiles>=24", "numpy>=1.26"],
6262
package_data={
6363
"livekit.rtc": ["_proto/*.py", "py.typed", "*.pyi", "**/*.pyi"],
6464
"livekit.rtc.resources": ["*.so", "*.dylib", "*.dll", "LICENSE.md", "*.h"],

livekit-rtc/tests/test_mixer.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# type: ignore
2+
3+
from typing import AsyncIterator
4+
import numpy as np
5+
import pytest
6+
import matplotlib.pyplot as plt
7+
8+
from livekit.rtc import AudioMixer
9+
from livekit.rtc.audio_frame import AudioFrame
10+
11+
SAMPLE_RATE = 48000
12+
# Use 100ms blocks (i.e. 4800 samples per frame at 48 kHz)
13+
BLOCKSIZE = SAMPLE_RATE // 10
14+
15+
16+
async def sine_wave_generator(freq: float, duration: float) -> AsyncIterator[AudioFrame]:
    """
    Yield mono int16 frames of a continuous sine tone at `freq` Hz.

    Emits as many whole BLOCKSIZE-sample frames as fit into `duration` seconds
    at SAMPLE_RATE; a trailing partial block is not produced.
    """
    frame_count = int((duration * SAMPLE_RATE) // BLOCKSIZE)
    base_times = np.arange(BLOCKSIZE) / SAMPLE_RATE
    for index in range(frame_count):
        # Offset each frame's time axis so the phase carries over between frames.
        times = base_times + index * BLOCKSIZE / SAMPLE_RATE
        # Amplitude 0.3 keeps the sum of two such tones inside the int16 range.
        wave = 0.3 * np.sin(2 * np.pi * freq * times)
        # Scale the float signal in [-0.3, 0.3] to int16 sample values.
        pcm = (wave * 32767).astype(np.int16)
        yield AudioFrame(pcm.tobytes(), SAMPLE_RATE, 1, BLOCKSIZE)
33+
34+
35+
@pytest.mark.asyncio
async def test_mixer_two_sine_waves():
    """
    Test that mixing two sine waves (440Hz and 880Hz) produces an output
    containing both frequency components.
    """
    duration = 1.0
    mixer = AudioMixer(
        sample_rate=SAMPLE_RATE,
        num_channels=1,
        blocksize=BLOCKSIZE,
        stream_timeout_ms=100,
        capacity=100,
    )
    try:
        mixer.add_stream(sine_wave_generator(440, duration))
        mixer.add_stream(sine_wave_generator(880, duration))
        mixer.end_input()

        mixed_signals = [
            np.frombuffer(frame.data.tobytes(), dtype=np.int16) async for frame in mixer
        ]
    finally:
        # Always stop the background mixing task, even when iteration fails.
        await mixer.aclose()

    if not mixed_signals:
        pytest.fail("No frames were produced by the mixer.")

    mixed_signal = np.concatenate(mixed_signals)

    # No plotting here: plt.show() would block an automated run under an
    # interactive matplotlib backend.

    # Use FFT to analyze frequency components.
    fft = np.fft.rfft(mixed_signal)
    freqs = np.fft.rfftfreq(len(mixed_signal), 1 / SAMPLE_RATE)
    magnitude = np.abs(fft)

    # Identify peak frequencies. We'll pick the 5 highest peaks.
    peak_indices = np.argsort(magnitude)[-5:]
    peak_freqs = freqs[peak_indices]

    print("Peak frequencies:", peak_freqs)

    # Assert that the peaks include 440Hz and 880Hz (with a tolerance of ±5 Hz)
    assert np.isclose(peak_freqs, 440, atol=5).any(), f"Expected 440 Hz in peaks, got: {peak_freqs}"
    assert np.isclose(peak_freqs, 880, atol=5).any(), f"Expected 880 Hz in peaks, got: {peak_freqs}"

0 commit comments

Comments
 (0)