Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/3645.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Fixed `SmallWebRTCTransport` not respecting `TransportParams.audio_out_10ms_chunks` parameter. The transport was hardcoded to produce 10ms audio frames regardless of the configured chunk size.
40 changes: 28 additions & 12 deletions src/pipecat/transports/smallwebrtc/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,29 @@ class RawAudioTrack(AudioStreamTrack):
supporting queued audio data with proper synchronization.
"""

def __init__(self, sample_rate):
def __init__(self, sample_rate: int, num_10ms_chunks: int = 1):
"""Initialize the raw audio track.

Args:
sample_rate: The audio sample rate in Hz.
num_10ms_chunks: Number of 10ms chunks per output frame (default 1).

Raises:
ValueError: If num_10ms_chunks is not a positive integer.
"""
if num_10ms_chunks < 1:
raise ValueError(f"num_10ms_chunks must be a positive integer, got {num_10ms_chunks}")
super().__init__()
self._sample_rate = sample_rate
self._num_10ms_chunks = num_10ms_chunks
self._samples_per_10ms = sample_rate * 10 // 1000
self._bytes_per_10ms = self._samples_per_10ms * 2 # 16-bit (2 bytes per sample)
# Calculate chunk size based on num_10ms_chunks
self._samples_per_chunk = self._samples_per_10ms * num_10ms_chunks
self._bytes_per_chunk = self._bytes_per_10ms * num_10ms_chunks
self._timestamp = 0
self._start = time.time()
# Queue of (bytes, future), broken into 10ms sub chunks as needed
# Queue of (bytes, future), broken into configured chunk sizes as needed
self._chunk_queue = deque()

def add_audio_bytes(self, audio_bytes: bytes):
Expand All @@ -103,17 +113,20 @@ def add_audio_bytes(self, audio_bytes: bytes):
A Future that completes when the data is processed.

Raises:
ValueError: If audio bytes are not a multiple of 10ms size.
ValueError: If audio bytes are not a multiple of the configured chunk size.
"""
if len(audio_bytes) % self._bytes_per_10ms != 0:
raise ValueError("Audio bytes must be a multiple of 10ms size.")
if len(audio_bytes) % self._bytes_per_chunk != 0:
raise ValueError(
f"Audio bytes must be a multiple of {self._num_10ms_chunks * 10}ms size "
f"({self._bytes_per_chunk} bytes)."
)
future = asyncio.get_running_loop().create_future()

# Break input into 10ms chunks
for i in range(0, len(audio_bytes), self._bytes_per_10ms):
chunk = audio_bytes[i : i + self._bytes_per_10ms]
# Break input into configured chunk sizes
for i in range(0, len(audio_bytes), self._bytes_per_chunk):
chunk = audio_bytes[i : i + self._bytes_per_chunk]
# Only the last chunk carries the future to be resolved once fully consumed
fut = future if i + self._bytes_per_10ms >= len(audio_bytes) else None
fut = future if i + self._bytes_per_chunk >= len(audio_bytes) else None
self._chunk_queue.append((chunk, fut))

return future
Expand All @@ -135,7 +148,7 @@ async def recv(self):
if future and not future.done():
future.set_result(True)
else:
chunk = bytes(self._bytes_per_10ms) # silence
chunk = bytes(self._bytes_per_chunk) # silence

# Convert the byte data to an ndarray of int16 samples
samples = np.frombuffer(chunk, dtype=np.int16)
Expand All @@ -145,7 +158,7 @@ async def recv(self):
frame.sample_rate = self._sample_rate
frame.pts = self._timestamp
frame.time_base = fractions.Fraction(1, self._sample_rate)
self._timestamp += self._samples_per_10ms
self._timestamp += self._samples_per_chunk
return frame


Expand Down Expand Up @@ -493,7 +506,10 @@ async def _handle_client_connected(self):
self._video_input_track = self._webrtc_connection.video_input_track()
self._screen_video_track = self._webrtc_connection.screen_video_input_track()
if self._params.audio_out_enabled:
self._audio_output_track = RawAudioTrack(sample_rate=self._out_sample_rate)
self._audio_output_track = RawAudioTrack(
sample_rate=self._out_sample_rate,
num_10ms_chunks=self._params.audio_out_10ms_chunks,
)
self._webrtc_connection.replace_audio_track(self._audio_output_track)

if self._params.video_out_enabled:
Expand Down
170 changes: 170 additions & 0 deletions tests/test_smallwebrtc_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

# pyright: reportConstantRedefinition=false
# pyright: reportPrivateUsage=false, reportUnknownMemberType=false
# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false
# pyright: reportOperatorIssue=false
# pyright: reportOptionalCall=false

import unittest
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pipecat.transports.smallwebrtc.transport import RawAudioTrack

try:
from pipecat.transports.smallwebrtc.transport import RawAudioTrack

WEBRTC_AVAILABLE = True
except (ImportError, Exception):
WEBRTC_AVAILABLE = False
RawAudioTrack = None # type: ignore[misc,assignment]


@unittest.skipUnless(WEBRTC_AVAILABLE, "webrtc dependencies not installed")
class TestRawAudioTrack(unittest.IsolatedAsyncioTestCase):
"""Tests for the RawAudioTrack class."""

def test_default_chunk_size_is_10ms(self):
"""Test that default chunk size is 10ms (num_10ms_chunks=1)."""
sample_rate = 16000
track = RawAudioTrack(sample_rate=sample_rate)

# 10ms at 16kHz = 160 samples, 2 bytes per sample = 320 bytes
expected_bytes = int(sample_rate * 10 / 1000) * 2
self.assertEqual(track._bytes_per_chunk, expected_bytes)
self.assertEqual(track._bytes_per_chunk, 320)

def test_custom_chunk_size_40ms(self):
"""Test that num_10ms_chunks=4 produces 40ms chunks."""
sample_rate = 16000
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=4)

# 40ms at 16kHz = 640 samples, 2 bytes per sample = 1280 bytes
expected_bytes = int(sample_rate * 40 / 1000) * 2
self.assertEqual(track._bytes_per_chunk, expected_bytes)
self.assertEqual(track._bytes_per_chunk, 1280)

def test_custom_chunk_size_20ms(self):
"""Test that num_10ms_chunks=2 produces 20ms chunks."""
sample_rate = 16000
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=2)

# 20ms at 16kHz = 320 samples, 2 bytes per sample = 640 bytes
expected_bytes = int(sample_rate * 20 / 1000) * 2
self.assertEqual(track._bytes_per_chunk, expected_bytes)
self.assertEqual(track._bytes_per_chunk, 640)

async def test_add_audio_bytes_queues_correct_chunks(self):
"""Test that add_audio_bytes breaks audio into correct chunk sizes."""
sample_rate = 16000
num_chunks = 4 # 40ms
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=num_chunks)

# Create 80ms of audio (2 chunks of 40ms each)
audio_bytes = bytes(track._bytes_per_chunk * 2)
track.add_audio_bytes(audio_bytes)

# Should have exactly 2 chunks in the queue
self.assertEqual(len(track._chunk_queue), 2)

# Each chunk should be the correct size
chunk1, _ = track._chunk_queue[0]
chunk2, _ = track._chunk_queue[1]
self.assertEqual(len(chunk1), track._bytes_per_chunk)
self.assertEqual(len(chunk2), track._bytes_per_chunk)

async def test_add_audio_bytes_rejects_invalid_size(self):
"""Test that add_audio_bytes rejects audio not a multiple of chunk size."""
sample_rate = 16000
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=4)

# Create audio that's not a multiple of 40ms chunk size
invalid_audio = bytes(track._bytes_per_chunk + 100)

with self.assertRaises(ValueError) as ctx:
track.add_audio_bytes(invalid_audio)

self.assertIn("40ms", str(ctx.exception))

async def test_recv_returns_correct_frame_size(self):
"""Test that recv() returns AudioFrames with correct sample count."""
sample_rate = 16000
num_chunks = 4 # 40ms
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=num_chunks)

# Add one 40ms chunk of audio
audio_bytes = bytes(track._bytes_per_chunk)
track.add_audio_bytes(audio_bytes)

# Receive the frame
frame = await track.recv()

# Frame should have correct number of samples (40ms worth)
expected_samples = int(sample_rate * 40 / 1000) # 640 samples
self.assertEqual(frame.samples, expected_samples)

async def test_recv_silence_has_correct_size(self):
"""Test that silence frames have correct size when queue is empty."""
sample_rate = 16000
num_chunks = 4 # 40ms
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=num_chunks)

# Don't add any audio - should get silence
frame = await track.recv()

# Silence frame should have correct number of samples
expected_samples = int(sample_rate * 40 / 1000) # 640 samples
self.assertEqual(frame.samples, expected_samples)

async def test_timestamp_advances_by_chunk_samples(self):
"""Test that timestamp advances correctly based on chunk size."""
sample_rate = 16000
num_chunks = 4 # 40ms
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=num_chunks)

# Receive first frame and check its timestamp
frame1 = await track.recv()
# Receive second frame
frame2 = await track.recv()

# Timestamp should advance by samples_per_chunk between frames
self.assertIsNotNone(frame1.pts)
self.assertIsNotNone(frame2.pts)
expected_samples = int(sample_rate * 40 / 1000) # 640 samples
self.assertEqual(frame2.pts - frame1.pts, expected_samples)

def test_different_sample_rates(self):
"""Test chunk size calculation at different sample rates."""
test_cases = [
(8000, 4, 640), # 8kHz, 40ms = 320 samples * 2 bytes = 640 bytes
(16000, 4, 1280), # 16kHz, 40ms = 640 samples * 2 bytes = 1280 bytes
(24000, 4, 1920), # 24kHz, 40ms = 960 samples * 2 bytes = 1920 bytes
(48000, 4, 3840), # 48kHz, 40ms = 1920 samples * 2 bytes = 3840 bytes
]

for sample_rate, num_chunks, expected_bytes in test_cases:
track = RawAudioTrack(sample_rate=sample_rate, num_10ms_chunks=num_chunks)
self.assertEqual(track._bytes_per_chunk, expected_bytes)

def test_invalid_num_10ms_chunks_zero(self):
"""Test that num_10ms_chunks=0 raises ValueError."""
with self.assertRaises(ValueError) as ctx:
RawAudioTrack(sample_rate=16000, num_10ms_chunks=0)

self.assertIn("positive integer", str(ctx.exception))

def test_invalid_num_10ms_chunks_negative(self):
"""Test that negative num_10ms_chunks raises ValueError."""
with self.assertRaises(ValueError) as ctx:
RawAudioTrack(sample_rate=16000, num_10ms_chunks=-1)

self.assertIn("positive integer", str(ctx.exception))


if __name__ == "__main__":
unittest.main()