Commit b62bb7c

add tests, thx gpt
1 parent a765768 commit b62bb7c

4 files changed: +102 -24 lines changed


dev-requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -9,3 +9,5 @@ auditwheel; sys_platform == 'linux'
 cibuildwheel
 
 pytest
+
+matplotlib

livekit-rtc/livekit/rtc/audio_mixer.py

Lines changed: 14 additions & 23 deletions

@@ -1,12 +1,11 @@
-from typing import AsyncIterator, Optional
 import asyncio
 import numpy as np
 import contextlib
 from dataclasses import dataclass
+from typing import AsyncIterator, Optional
 from .audio_frame import AudioFrame
 from .log import logger
 
-
 _Stream = AsyncIterator[AudioFrame]
 
 
@@ -58,7 +57,9 @@ def __init__(
         self._chunk_size: int = blocksize if blocksize > 0 else int(sample_rate // 10)
         self._stream_timeout_ms: int = stream_timeout_ms
         self._queue: asyncio.Queue[Optional[AudioFrame]] = asyncio.Queue(maxsize=capacity)
-        self._closed: bool = False
+        # _ending signals that no new streams will be added,
+        # but we continue processing until all streams are exhausted.
+        self._ending: bool = False
         self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer())
 
     def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
@@ -87,25 +88,10 @@ def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None:
         self._streams.discard(stream)
         self._buffers.pop(stream, None)
 
-    async def __aiter__(self) -> "AudioMixer":
-        """
-        Return the async iterator interface for the mixer.
-
-        Returns:
-            AudioMixer: The mixer itself, as it implements the async iterator protocol.
-        """
+    def __aiter__(self) -> "AudioMixer":
         return self
 
     async def __anext__(self) -> AudioFrame:
-        """
-        Retrieve the next mixed AudioFrame from the output queue.
-
-        Returns:
-            AudioFrame: The next mixed audio frame.
-
-        Raises:
-            StopAsyncIteration: When the mixer is closed and no more frames are available.
-        """
         item = await self._queue.get()
         if item is None:
             raise StopAsyncIteration
@@ -117,7 +103,7 @@ async def aclose(self) -> None:
 
         This cancels the mixing task, and any unconsumed output in the queue may be dropped.
         """
-        self._closed = True
+        self._ending = True
         self._mixer_task.cancel()
         with contextlib.suppress(asyncio.CancelledError):
             await self._mixer_task
@@ -129,13 +115,18 @@ def end_input(self) -> None:
         This method marks the mixer as closed so that it flushes any remaining buffered output before ending.
         Note that existing streams will still be processed until exhausted.
         """
-        self._closed = True
+        self._ending = True
 
     async def _mixer(self) -> None:
-        while not self._closed:
+        while True:
+            # If we're in ending mode and there are no more streams, exit.
+            if self._ending and not self._streams:
+                break
+
             if not self._streams:
                 await asyncio.sleep(0.01)
                 continue
+
             tasks = [
                 self._get_contribution(
                     stream,
@@ -185,7 +176,7 @@ async def _get_contribution(
                     stream.__anext__(), timeout=self._stream_timeout_ms / 1000
                 )
            except asyncio.TimeoutError:
-                logger.warning(f"AudioMixer: stream {stream} timeout, ignoring`")
+                logger.warning(f"AudioMixer: stream {stream} timeout, ignoring")
                 break
             except StopAsyncIteration:
                 exhausted = True
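
For context on the change above: the old _closed flag stopped the mixing loop outright, while the new _ending flag only marks that no further streams will be added; the loop keeps producing mixed frames until every registered stream is exhausted, and aclose() remains the hard stop that cancels the task. Below is a minimal usage sketch of that lifecycle, assuming the constructor arguments shown in the new test; the consume() wrapper and its streams argument are illustrative, not part of this repo.

    import asyncio
    from livekit.rtc import AudioMixer

    async def consume(streams):
        mixer = AudioMixer(
            sample_rate=48000,
            num_channels=1,
            blocksize=4800,        # 100 ms at 48 kHz
            stream_timeout_ms=100,
            capacity=100,
        )
        for s in streams:
            mixer.add_stream(s)
        # end_input(): no further add_stream() calls are coming; the mixer
        # drains the existing streams and then ends iteration (presumably via
        # the None sentinel that __anext__ checks for).
        mixer.end_input()
        async for frame in mixer:
            pass  # consume mixed AudioFrame objects here
        # aclose(): cancels the mixing task; unconsumed queued frames may be dropped.
        await mixer.aclose()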

livekit-rtc/setup.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ def finalize_options(self):
     license="Apache-2.0",
     packages=setuptools.find_namespace_packages(include=["livekit.*"]),
     python_requires=">=3.9.0",
-    install_requires=["protobuf>=4.25.0", "types-protobuf>=3", "aiofiles>=24"],
+    install_requires=["protobuf>=4.25.0", "types-protobuf>=3", "aiofiles>=24", "numpy>=1.26"],
     package_data={
         "livekit.rtc": ["_proto/*.py", "py.typed", "*.pyi", "**/*.pyi"],
         "livekit.rtc.resources": ["*.so", "*.dylib", "*.dll", "LICENSE.md", "*.h"],

livekit-rtc/tests/test_mixer.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+from typing import AsyncIterator
+import numpy as np
+import pytest
+import matplotlib.pyplot as plt
+
+from livekit.rtc import AudioMixer
+from livekit.rtc.audio_frame import AudioFrame
+
+SAMPLE_RATE = 48000
+# Use 100ms blocks (i.e. 1600 samples per frame)
+BLOCKSIZE = SAMPLE_RATE // 10
+
+
+async def sine_wave_generator(freq: float, duration: float) -> AsyncIterator[AudioFrame]:
+    total_frames = int((duration * SAMPLE_RATE) // BLOCKSIZE)
+    t_frame = np.arange(BLOCKSIZE) / SAMPLE_RATE
+    for i in range(total_frames):
+        # Shift the time for each frame so that the sine wave is continuous
+        t = t_frame + i * BLOCKSIZE / SAMPLE_RATE
+        # Create a sine wave with amplitude 0.3 (to avoid clipping when summing)
+        signal = 0.3 * np.sin(2 * np.pi * freq * t)
+        # Convert from float [-0.5, 0.5] to int16 values
+        signal_int16 = np.int16(signal * 32767)
+        frame = AudioFrame(
+            signal_int16.tobytes(),
+            SAMPLE_RATE,
+            1,
+            BLOCKSIZE,
+        )
+        yield frame
+
+
+@pytest.mark.asyncio
+async def test_mixer_two_sine_waves():
+    """
+    Test that mixing two sine waves (440Hz and 880Hz) produces an output
+    containing both frequency components.
+    """
+    duration = 1.0
+    mixer = AudioMixer(
+        sample_rate=SAMPLE_RATE,
+        num_channels=1,
+        blocksize=BLOCKSIZE,
+        stream_timeout_ms=100,
+        capacity=100,
+    )
+    stream1 = sine_wave_generator(440, duration)
+    stream2 = sine_wave_generator(880, duration)
+    mixer.add_stream(stream1)
+    mixer.add_stream(stream2)
+    mixer.end_input()
+
+    mixed_signals = []
+    async for frame in mixer:
+        data = np.frombuffer(frame.data.tobytes(), dtype=np.int16)
+        mixed_signals.append(data)
+
+    await mixer.aclose()
+
+    if not mixed_signals:
+        pytest.fail("No frames were produced by the mixer.")
+
+    mixed_signal = np.concatenate(mixed_signals)
+
+    plt.figure(figsize=(10, 4))
+    plt.plot(mixed_signal[:1000])  # plot 1000
+    plt.title("Mixed Signal")
+    plt.xlabel("Sample")
+    plt.ylabel("Amplitude")
+    plt.show()
+
+    # Use FFT to analyze frequency components.
+    fft = np.fft.rfft(mixed_signal)
+    freqs = np.fft.rfftfreq(len(mixed_signal), 1 / SAMPLE_RATE)
+    magnitude = np.abs(fft)
+
+    # Identify peak frequencies. We'll pick the 5 highest peaks.
+    peak_indices = np.argsort(magnitude)[-5:]
+    peak_freqs = freqs[peak_indices]
+
+    print("Peak frequencies:", peak_freqs)
+
+    # Assert that the peaks include 440Hz and 880Hz (with a tolerance of ±5 Hz)
+    assert any(np.isclose(peak_freqs, 440, atol=5)), f"Expected 440 Hz in peaks, got: {peak_freqs}"
+    assert any(np.isclose(peak_freqs, 880, atol=5)), f"Expected 880 Hz in peaks, got: {peak_freqs}"
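
A standalone sanity check (not part of the commit) for the two assumptions this test leans on: an amplitude of 0.3 per tone leaves headroom when two tones are summed into int16, and for a whole second of signal at 48 kHz the rfft bins fall on integer frequencies, so the two largest peaks land exactly at 440 Hz and 880 Hz.

    import numpy as np

    SAMPLE_RATE = 48000
    t = np.arange(SAMPLE_RATE) / SAMPLE_RATE  # exactly one second of samples
    mix = 0.3 * np.sin(2 * np.pi * 440 * t) + 0.3 * np.sin(2 * np.pi * 880 * t)

    # Headroom: the sum is bounded by 0.6 of full scale, so int16 conversion cannot clip.
    assert np.max(np.abs(mix)) <= 0.6
    mix_int16 = np.int16(mix * 32767)

    # The two largest rfft magnitudes sit at the source frequencies.
    freqs = np.fft.rfftfreq(len(mix_int16), 1 / SAMPLE_RATE)
    magnitude = np.abs(np.fft.rfft(mix_int16))
    top2 = np.sort(freqs[np.argsort(magnitude)[-2:]])
    print(top2)  # expected: [440. 880.]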
