@@ -983,17 +983,14 @@ def test_get_all_samples(self, asset, stop_seconds):
983983 if stop_seconds == "duration" :
984984 stop_seconds = asset .duration_seconds
985985
986- samples = decoder .get_samples_played_in_range (
987- start_seconds = 0 , stop_seconds = stop_seconds
988- )
986+ samples = decoder .get_samples_played_in_range (stop_seconds = stop_seconds )
989987
990988 reference_frames = asset .get_frame_data_by_range (
991989 start = 0 , stop = asset .get_frame_index (pts_seconds = asset .duration_seconds ) + 1
992990 )
993991
994992 torch .testing .assert_close (samples .data , reference_frames )
995993 assert samples .sample_rate == asset .sample_rate
996-
997994 assert samples .pts_seconds == asset .get_frame_info (idx = 0 ).pts_seconds
998995
999996 @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
@@ -1079,15 +1076,15 @@ def test_single_channel(self):
10791076 asset = SINE_MONO_S32
10801077 decoder = AudioDecoder (asset .path )
10811078
1082- samples = decoder .get_samples_played_in_range (start_seconds = 0 , stop_seconds = 2 )
1079+ samples = decoder .get_samples_played_in_range (stop_seconds = 2 )
10831080 assert samples .data .shape [0 ] == asset .num_channels == 1
10841081
10851082 def test_format_conversion (self ):
10861083 asset = SINE_MONO_S32
10871084 decoder = AudioDecoder (asset .path )
10881085 assert decoder .metadata .sample_format == asset .sample_format == "s32"
10891086
1090- all_samples = decoder .get_samples_played_in_range (start_seconds = 0 )
1087+ all_samples = decoder .get_samples_played_in_range ()
10911088 assert all_samples .data .dtype == torch .float32
10921089
10931090 reference_frames = asset .get_frame_data_by_range (start = 0 , stop = asset .num_frames )
@@ -1164,7 +1161,7 @@ def test_sample_rate_conversion_stereo(self):
11641161 assert asset .sample_rate == 8000
11651162 assert asset .num_channels == 2
11661163 decoder = AudioDecoder (asset .path , sample_rate = 44_100 )
1167- decoder .get_samples_played_in_range (start_seconds = 0 )
1164+ decoder .get_samples_played_in_range ()
11681165
11691166 def test_downsample_empty_frame (self ):
11701167 # Non-regression test for
@@ -1184,13 +1181,13 @@ def test_downsample_empty_frame(self):
11841181 asset = NASA_AUDIO_MP3_44100
11851182 assert asset .sample_rate == 44_100
11861183 decoder = AudioDecoder (asset .path , sample_rate = 8_000 )
1187- frames_44100_to_8000 = decoder .get_samples_played_in_range (start_seconds = 0 )
1184+ frames_44100_to_8000 = decoder .get_samples_played_in_range ()
11881185
11891186 # Just checking correctness now
11901187 asset = NASA_AUDIO_MP3
11911188 assert asset .sample_rate == 8_000
11921189 decoder = AudioDecoder (asset .path )
1193- frames_8000 = decoder .get_samples_played_in_range (start_seconds = 0 )
1190+ frames_8000 = decoder .get_samples_played_in_range ()
11941191 torch .testing .assert_close (
11951192 frames_44100_to_8000 .data , frames_8000 .data , atol = 0.03 , rtol = 0
11961193 )
@@ -1214,4 +1211,11 @@ def test_s16_ffmpeg4_bug(self):
12141211 else contextlib .nullcontext ()
12151212 )
12161213 with cm :
1217- decoder .get_samples_played_in_range (start_seconds = 0 )
1214+ decoder .get_samples_played_in_range ()
1215+
1216+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
1217+ @pytest .mark .parametrize ("sample_rate" , (None , 8000 , 16_000 , 44_1000 ))
1218+ def test_samples_duration (self , asset , sample_rate ):
1219+ decoder = AudioDecoder (asset .path , sample_rate = sample_rate )
1220+ samples = decoder .get_samples_played_in_range (start_seconds = 1 , stop_seconds = 2 )
1221+ assert samples .duration_seconds == 1
0 commit comments