Skip to content

Commit 7b09315

Browse files
committed
Add tests
1 parent 8deb079 commit 7b09315

File tree

3 files changed

+76
-9
lines changed

3 files changed

+76
-9
lines changed

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ def __init__(
2525
source: Union[str, Path, bytes, Tensor],
2626
*,
2727
stream_index: Optional[int] = None,
28+
sample_rate: Optional[int] = None,
2829
):
2930
self._decoder = create_decoder(source=source, seek_mode="approximate")
3031

31-
core.add_audio_stream(self._decoder, stream_index=stream_index)
32+
core.add_audio_stream(
33+
self._decoder, stream_index=stream_index, sample_rate=sample_rate
34+
)
3235

3336
(
3437
self.metadata,
@@ -39,6 +42,10 @@ def __init__(
3942
decoder=self._decoder, stream_index=stream_index, media_type="audio"
4043
)
4144
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
45+
self._source_sample_rate = self.metadata.sample_rate
46+
self._desired_sample_rate = (
47+
sample_rate if sample_rate is not None else self._source_sample_rate
48+
)
4249

4350
def get_samples_played_in_range(
4451
self, start_seconds: float, stop_seconds: Optional[float] = None
@@ -75,11 +82,7 @@ def get_samples_played_in_range(
7582
# So we do some basic math to figure out the position of the view that
7683
# we'll return.
7784

78-
# TODO: sample_rate is either the original one from metadata, or the
79-
# user-specified one (NIY)
80-
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
81-
sample_rate = self.metadata.sample_rate
82-
85+
sample_rate = self._desired_sample_rate
8386
# TODO: metadata's sample_rate should probably not be Optional
8487
assert sample_rate is not None # mypy.
8588

@@ -94,7 +97,7 @@ def get_samples_played_in_range(
9497
output_pts_seconds = first_pts
9598

9699
num_samples = frames.shape[1]
97-
last_pts = first_pts + num_samples / self.metadata.sample_rate
100+
last_pts = first_pts + num_samples / sample_rate
98101
if stop_seconds is not None and stop_seconds < last_pts:
99102
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
100103
else:

test/decoders/test_decoders.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
NASA_AUDIO_MP3,
2727
NASA_VIDEO,
2828
SINE_MONO_S32,
29+
SINE_MONO_S32_44100,
30+
SINE_MONO_S32_8000,
2931
)
3032

3133

@@ -1088,3 +1090,65 @@ def test_format_conversion(self):
10881090

10891091
reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames)
10901092
torch.testing.assert_close(all_samples.data, reference_frames)
1093+
1094+
@pytest.mark.parametrize(
1095+
"start_seconds, stop_seconds",
1096+
(
1097+
(0, None),
1098+
(0, 4),
1099+
(0, 3),
1100+
(2, None),
1101+
(2, 3),
1102+
),
1103+
)
1104+
def test_sample_rate_conversion(self, start_seconds, stop_seconds):
1105+
# When start_seconds is not exactly 0, we have to increase the tolerance
1106+
# a bit. This is because sample_rate conversion relies on a sliding
1107+
# window of samples: if we start a stream in the middle, the first few
1108+
# samples aren't able to take advantage of the preceding samples.
1109+
atol = 1e-4 if start_seconds == 0 else 1e-2
1110+
rtol = 1e-6
1111+
1112+
# Upsample
1113+
decoder = AudioDecoder(SINE_MONO_S32_44100.path)
1114+
assert decoder.metadata.sample_rate == 44_100
1115+
frames_44100_native = decoder.get_samples_played_in_range(
1116+
start_seconds=start_seconds, stop_seconds=stop_seconds
1117+
)
1118+
assert frames_44100_native.sample_rate == 44_100
1119+
1120+
decoder = AudioDecoder(SINE_MONO_S32.path, sample_rate=44_100)
1121+
frames_upsampled_to_44100 = decoder.get_samples_played_in_range(
1122+
start_seconds=start_seconds, stop_seconds=stop_seconds
1123+
)
1124+
assert decoder.metadata.sample_rate == 16_000
1125+
assert frames_upsampled_to_44100.sample_rate == 44_100
1126+
1127+
torch.testing.assert_close(
1128+
frames_upsampled_to_44100.data,
1129+
frames_44100_native.data,
1130+
atol=atol,
1131+
rtol=rtol,
1132+
)
1133+
1134+
# Downsample
1135+
decoder = AudioDecoder(SINE_MONO_S32_8000.path)
1136+
assert decoder.metadata.sample_rate == 8000
1137+
frames_8000_native = decoder.get_samples_played_in_range(
1138+
start_seconds=start_seconds, stop_seconds=stop_seconds
1139+
)
1140+
assert frames_8000_native.sample_rate == 8000
1141+
1142+
decoder = AudioDecoder(SINE_MONO_S32.path, sample_rate=8000)
1143+
frames_downsampled_to_8000 = decoder.get_samples_played_in_range(
1144+
start_seconds=start_seconds, stop_seconds=stop_seconds
1145+
)
1146+
assert decoder.metadata.sample_rate == 16_000
1147+
assert frames_downsampled_to_8000.sample_rate == 8000
1148+
1149+
torch.testing.assert_close(
1150+
frames_downsampled_to_8000.data,
1151+
frames_8000_native.data,
1152+
atol=atol,
1153+
rtol=rtol,
1154+
)

test/decoders/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,11 +884,11 @@ def test_decode_before_frame_start(self):
884884
torch.testing.assert_close(frames, all_frames)
885885

886886
def test_sample_rate_conversion(self):
887-
def get_all_frames(asset, sample_rate=None):
887+
def get_all_frames(asset, sample_rate=None, stop_seconds=None):
888888
decoder = create_from_file(str(asset.path), seek_mode="approximate")
889889
add_audio_stream(decoder, sample_rate=sample_rate)
890890
frames, *_ = get_frames_by_pts_in_range_audio(
891-
decoder, start_seconds=0, stop_seconds=None
891+
decoder, start_seconds=0, stop_seconds=stop_seconds
892892
)
893893
return frames
894894

0 commit comments

Comments
 (0)