Skip to content

Commit f4d1a50

Browse files
committed
support appending to pcm_data
1 parent 2758c08 commit f4d1a50

File tree

2 files changed

+254
-2
lines changed

2 files changed

+254
-2
lines changed

getstream/video/rtc/track_util.py

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class PcmData(NamedTuple):
3838

3939
format: str
4040
sample_rate: int
41-
samples: NDArray
41+
samples: NDArray = np.array([], dtype=np.int16)
4242
pts: Optional[int] = None # Presentation timestamp
4343
dts: Optional[int] = None # Decode timestamp
4444
time_base: Optional[float] = None # Time base for converting timestamps to seconds
@@ -521,7 +521,8 @@ def to_float32(self) -> "PcmData":
521521
).samples
522522

523523
# Convert to float32 and scale if needed
524-
if self.format == "s16" or (
524+
fmt = (self.format or "").lower()
525+
if fmt in ("s16", "int16") or (
525526
isinstance(arr, np.ndarray) and arr.dtype == np.int16
526527
):
527528
arr_f32 = arr.astype(np.float32) / 32768.0
@@ -539,6 +540,174 @@ def to_float32(self) -> "PcmData":
539540
channels=self.channels,
540541
)
541542

543+
def append(self, other: "PcmData") -> "PcmData":
544+
"""Append another PcmData to this one and return a new instance.
545+
546+
The input chunk is adjusted to match this instance's sample rate,
547+
channel count, and sample format before concatenation.
548+
549+
Notes:
550+
- Preserves shape semantics: mono as 1D, multi-channel as 2D [channels, samples].
551+
- Keeps metadata (sample_rate, format, channels, pts/dts/time_base) from self.
552+
- Does not modify self; returns a new PcmData.
553+
"""
554+
555+
# Early exits for empty cases
556+
def _is_empty(arr: Any) -> bool:
557+
try:
558+
return isinstance(arr, np.ndarray) and arr.size == 0
559+
except Exception:
560+
return False
561+
562+
# Normalize numpy arrays from bytes-like if needed
563+
def _ensure_ndarray(pcm: "PcmData") -> np.ndarray:
564+
if isinstance(pcm.samples, np.ndarray):
565+
return pcm.samples
566+
return PcmData.from_bytes(
567+
pcm.to_bytes(),
568+
sample_rate=pcm.sample_rate,
569+
format=pcm.format,
570+
channels=pcm.channels,
571+
).samples
572+
573+
# Adjust other to match sample rate and channels first
574+
other_adj = other
575+
if (
576+
other_adj.sample_rate != self.sample_rate
577+
or other_adj.channels != self.channels
578+
):
579+
other_adj = other_adj.resample(
580+
self.sample_rate, target_channels=self.channels
581+
)
582+
583+
# Then adjust format to match
584+
fmt = (self.format or "").lower()
585+
if fmt in ("f32", "float32"):
586+
other_adj = other_adj.to_float32()
587+
elif fmt in ("s16", "int16"):
588+
# Ensure int16 dtype and mark as s16
589+
arr = _ensure_ndarray(other_adj)
590+
if arr.dtype != np.int16:
591+
if other_adj.format == "f32":
592+
arr = (np.clip(arr.astype(np.float32), -1.0, 1.0) * 32767.0).astype(
593+
np.int16
594+
)
595+
else:
596+
arr = arr.astype(np.int16)
597+
other_adj = PcmData(
598+
samples=arr,
599+
sample_rate=other_adj.sample_rate,
600+
format="s16",
601+
pts=other_adj.pts,
602+
dts=other_adj.dts,
603+
time_base=other_adj.time_base,
604+
channels=other_adj.channels,
605+
)
606+
else:
607+
# For unknown formats, fallback to bytes round-trip in self's format
608+
other_adj = PcmData.from_bytes(
609+
other_adj.to_bytes(),
610+
sample_rate=self.sample_rate,
611+
format=self.format,
612+
channels=self.channels,
613+
)
614+
615+
# Ensure ndarrays for concatenation
616+
self_arr = _ensure_ndarray(self)
617+
other_arr = _ensure_ndarray(other_adj)
618+
619+
# If either is empty, return the other while preserving self's metadata
620+
if _is_empty(self_arr):
621+
# Conform shape to target channels semantics and dtype
622+
if isinstance(other_arr, np.ndarray):
623+
if (self.channels or 1) == 1 and other_arr.ndim > 1:
624+
other_arr = other_arr.reshape(-1)
625+
target_dtype = (
626+
np.float32
627+
if (self.format or "").lower() in ("f32", "float32")
628+
else np.int16
629+
)
630+
other_arr = other_arr.astype(target_dtype, copy=False)
631+
return PcmData(
632+
samples=other_arr,
633+
sample_rate=self.sample_rate,
634+
format=self.format,
635+
pts=self.pts,
636+
dts=self.dts,
637+
time_base=self.time_base,
638+
channels=self.channels,
639+
)
640+
if _is_empty(other_arr):
641+
return self
642+
643+
ch = max(1, int(self.channels or 1))
644+
645+
# Concatenate respecting shape conventions
646+
if ch == 1:
647+
# Mono: keep 1D shape
648+
if self_arr.ndim > 1:
649+
self_arr = self_arr.reshape(-1)
650+
if other_arr.ndim > 1:
651+
other_arr = other_arr.reshape(-1)
652+
out = np.concatenate([self_arr, other_arr])
653+
# Enforce dtype based on format
654+
if (self.format or "").lower() in (
655+
"f32",
656+
"float32",
657+
) and out.dtype != np.float32:
658+
out = out.astype(np.float32)
659+
elif (self.format or "").lower() in (
660+
"s16",
661+
"int16",
662+
) and out.dtype != np.int16:
663+
out = out.astype(np.int16)
664+
return PcmData(
665+
samples=out,
666+
sample_rate=self.sample_rate,
667+
format=self.format,
668+
pts=self.pts,
669+
dts=self.dts,
670+
time_base=self.time_base,
671+
channels=self.channels,
672+
)
673+
else:
674+
# Multi-channel: normalize to (channels, samples)
675+
def _to_cmaj(arr: np.ndarray, channels: int) -> np.ndarray:
676+
if arr.ndim == 2:
677+
if arr.shape[0] == channels:
678+
return arr
679+
if arr.shape[1] == channels:
680+
return arr.T
681+
# Ambiguous; assume time-major and transpose
682+
return arr.T
683+
# 1D input: replicate across channels
684+
return np.tile(arr.reshape(1, -1), (channels, 1))
685+
686+
self_cmaj = _to_cmaj(self_arr, ch)
687+
other_cmaj = _to_cmaj(other_arr, ch)
688+
out = np.concatenate([self_cmaj, other_cmaj], axis=1)
689+
# Enforce dtype based on format
690+
if (self.format or "").lower() in (
691+
"f32",
692+
"float32",
693+
) and out.dtype != np.float32:
694+
out = out.astype(np.float32)
695+
elif (self.format or "").lower() in (
696+
"s16",
697+
"int16",
698+
) and out.dtype != np.int16:
699+
out = out.astype(np.int16)
700+
701+
return PcmData(
702+
samples=out,
703+
sample_rate=self.sample_rate,
704+
format=self.format,
705+
pts=self.pts,
706+
dts=self.dts,
707+
time_base=self.time_base,
708+
channels=self.channels,
709+
)
710+
542711
@classmethod
543712
def from_response(
544713
cls,

tests/rtc/test_pcm_data.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,86 @@ def test_to_float32_converts_int16_and_preserves_metadata():
293293
f32_2 = f32.to_float32()
294294
assert f32_2.samples.dtype == np.float32
295295
assert np.allclose(f32_2.samples, f32.samples, atol=1e-7)
296+
297+
298+
def test_append_mono_s16_concatenates_and_preserves_format():
299+
sr = 16000
300+
a = np.array([1, 2, 3, 4], dtype=np.int16)
301+
b = np.array([5, 6], dtype=np.int16)
302+
303+
pcm_a = PcmData(samples=a, sample_rate=sr, format="s16", channels=1)
304+
pcm_b = PcmData(samples=b, sample_rate=sr, format="s16", channels=1)
305+
306+
out = pcm_a.append(pcm_b)
307+
308+
assert out.format == "s16"
309+
assert out.channels == 1
310+
assert isinstance(out.samples, np.ndarray)
311+
assert out.samples.dtype == np.int16
312+
assert out.samples.ndim == 1
313+
assert out.sample_rate == sr
314+
assert np.array_equal(out.samples, np.array([1, 2, 3, 4, 5, 6], dtype=np.int16))
315+
316+
317+
def test_append_resamples_and_converts_to_match_target_format():
318+
# Target is float32 stereo 48kHz
319+
base = np.array([[0.0, 0.1, -0.1], [0.0, 0.1, -0.1]], dtype=np.float32)
320+
pcm_target = PcmData(samples=base, sample_rate=48000, format="f32", channels=2)
321+
322+
# Other is s16 mono 16kHz
323+
other_raw = np.array([1000, -1000, 1000, -1000, 1000, -1000], dtype=np.int16)
324+
pcm_other = PcmData(samples=other_raw, sample_rate=16000, format="s16", channels=1)
325+
326+
# Pre-compute expected resampled length by using the same resample pipeline
327+
other_resampled = pcm_other.resample(48000, target_channels=2).to_float32()
328+
if other_resampled.samples.ndim == 2:
329+
expected_added = other_resampled.samples.shape[1]
330+
else:
331+
expected_added = other_resampled.samples.shape[0]
332+
333+
out = pcm_target.append(pcm_other)
334+
335+
# Check format/channels preserved and dtype matches
336+
assert out.format == "f32"
337+
assert out.channels == 2
338+
assert isinstance(out.samples, np.ndarray) and out.samples.dtype == np.float32
339+
assert out.samples.shape[0] == 2
340+
341+
# First part must equal the original base (append should not alter original)
342+
assert np.allclose(out.samples[:, : base.shape[1]], base)
343+
344+
# Total length should be base + resampled other
345+
assert out.samples.shape[1] == base.shape[1] + expected_added
346+
347+
348+
def test_append_empty_buffer_float32_adjusts_other_and_keeps_meta():
349+
# Create an empty buffer specifying desired output meta using alternate format name
350+
buffer = PcmData(format="float32", sample_rate=16000, channels=1)
351+
352+
# Other is int16 stereo at 48kHz, small ramp
353+
other = np.array(
354+
[[1000, -1000, 500, -500], [-1000, 1000, -500, 500]], dtype=np.int16
355+
)
356+
pcm_other = PcmData(samples=other, sample_rate=48000, format="s16", channels=2)
357+
358+
# Expected result if we first resample/downmix then convert to float32
359+
expected_pcm = pcm_other.resample(16000, target_channels=1).to_float32()
360+
361+
# Append to the empty buffer
362+
out = buffer.append(pcm_other)
363+
364+
# Metadata should be preserved from buffer
365+
assert out.format in ("f32", "float32")
366+
assert out.sample_rate == 16000
367+
assert out.channels == 1
368+
369+
# Data should match expected (mono float32)
370+
assert isinstance(out.samples, np.ndarray)
371+
assert out.samples.dtype == np.float32
372+
assert out.samples.ndim == 1
373+
# Normalize expected to 1D if needed
374+
if isinstance(expected_pcm.samples, np.ndarray) and expected_pcm.samples.ndim == 2:
375+
expected_samples = expected_pcm.samples.reshape(-1)
376+
else:
377+
expected_samples = expected_pcm.samples
378+
assert np.allclose(out.samples[-expected_samples.shape[0] :], expected_samples)

0 commit comments

Comments
 (0)