Skip to content

Commit 2e6f0ed

Browse files
authored
Add default for start_seconds parameter of get_samples_played_in_range (#588)
1 parent 5d91fd1 commit 2e6f0ed

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

examples/audio_decoding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def play_audio(samples):
6262
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method,
6363
# which returns an :class:`~torchcodec.AudioSamples` object:
6464

65-
samples = decoder.get_samples_played_in_range(start_seconds=0)
65+
samples = decoder.get_samples_played_in_range()
6666

6767
print(samples)
6868
play_audio(samples)

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,16 @@ def __init__(
7070
sample_rate if sample_rate is not None else self.metadata.sample_rate
7171
)
7272

73-
# TODO-AUDIO: start_seconds should be 0 by default
7473
def get_samples_played_in_range(
75-
self, start_seconds: float, stop_seconds: Optional[float] = None
74+
self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None
7675
) -> AudioSamples:
7776
"""Returns audio samples in the given range.
7877
7978
Samples are in the half open range [start_seconds, stop_seconds).
8079
8180
Args:
8281
start_seconds (float): Time, in seconds, of the start of the
83-
range.
82+
range. Default: 0.
8483
stop_seconds (float): Time, in seconds, of the end of the
8584
range. As a half open range, the end is excluded.
8685

test/decoders/test_decoders.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -983,9 +983,7 @@ def test_get_all_samples(self, asset, stop_seconds):
983983
if stop_seconds == "duration":
984984
stop_seconds = asset.duration_seconds
985985

986-
samples = decoder.get_samples_played_in_range(
987-
start_seconds=0, stop_seconds=stop_seconds
988-
)
986+
samples = decoder.get_samples_played_in_range(stop_seconds=stop_seconds)
989987

990988
reference_frames = asset.get_frame_data_by_range(
991989
start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
@@ -1078,15 +1076,15 @@ def test_single_channel(self):
10781076
asset = SINE_MONO_S32
10791077
decoder = AudioDecoder(asset.path)
10801078

1081-
samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=2)
1079+
samples = decoder.get_samples_played_in_range(stop_seconds=2)
10821080
assert samples.data.shape[0] == asset.num_channels == 1
10831081

10841082
def test_format_conversion(self):
10851083
asset = SINE_MONO_S32
10861084
decoder = AudioDecoder(asset.path)
10871085
assert decoder.metadata.sample_format == asset.sample_format == "s32"
10881086

1089-
all_samples = decoder.get_samples_played_in_range(start_seconds=0)
1087+
all_samples = decoder.get_samples_played_in_range()
10901088
assert all_samples.data.dtype == torch.float32
10911089

10921090
reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames)
@@ -1163,7 +1161,7 @@ def test_sample_rate_conversion_stereo(self):
11631161
assert asset.sample_rate == 8000
11641162
assert asset.num_channels == 2
11651163
decoder = AudioDecoder(asset.path, sample_rate=44_100)
1166-
decoder.get_samples_played_in_range(start_seconds=0)
1164+
decoder.get_samples_played_in_range()
11671165

11681166
def test_downsample_empty_frame(self):
11691167
# Non-regression test for
@@ -1183,13 +1181,13 @@ def test_downsample_empty_frame(self):
11831181
asset = NASA_AUDIO_MP3_44100
11841182
assert asset.sample_rate == 44_100
11851183
decoder = AudioDecoder(asset.path, sample_rate=8_000)
1186-
frames_44100_to_8000 = decoder.get_samples_played_in_range(start_seconds=0)
1184+
frames_44100_to_8000 = decoder.get_samples_played_in_range()
11871185

11881186
# Just checking correctness now
11891187
asset = NASA_AUDIO_MP3
11901188
assert asset.sample_rate == 8_000
11911189
decoder = AudioDecoder(asset.path)
1192-
frames_8000 = decoder.get_samples_played_in_range(start_seconds=0)
1190+
frames_8000 = decoder.get_samples_played_in_range()
11931191
torch.testing.assert_close(
11941192
frames_44100_to_8000.data, frames_8000.data, atol=0.03, rtol=0
11951193
)
@@ -1213,7 +1211,7 @@ def test_s16_ffmpeg4_bug(self):
12131211
else contextlib.nullcontext()
12141212
)
12151213
with cm:
1216-
decoder.get_samples_played_in_range(start_seconds=0)
1214+
decoder.get_samples_played_in_range()
12171215

12181216
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
12191217
@pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_1000))

0 commit comments

Comments
 (0)