@@ -741,11 +741,9 @@ def test_decode_start_equal_stop(self, asset):
741741
742742 @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
743743 def test_multiple_calls (self , asset ):
744- # Ensure that multiple calls are OK as long as we're decoding
745- # "sequentially", i.e. we don't require a backwards seek.
746- # And ensure a proper error is raised in such case.
747- # TODO-AUDIO We shouldn't error, we should just implement the seeking
748- # back to the beginning of the stream.
744+ # Ensure that multiple calls to get_frames_by_pts_in_range_audio on the
745+ # same decoder are supported, whether it involves forward seeks or
746+ # backwards seeks.
749747
750748 def get_reference_frames (start_seconds , stop_seconds ):
751749 # This stateless helper exists for convenience, to avoid
@@ -794,23 +792,22 @@ def get_reference_frames(start_seconds, stop_seconds):
794792 frames , get_reference_frames (start_seconds , stop_seconds )
795793 )
796794
797- # but starting immediately on the same frame raises
798- expected_match = "Audio decoder cannot seek backwards"
799- with pytest .raises (RuntimeError , match = expected_match ):
800- get_frames_by_pts_in_range_audio (
801- decoder , start_seconds = stop_seconds , stop_seconds = 6
802- )
795+ # starting immediately on the same frame is OK
796+ frames = get_frames_by_pts_in_range_audio (
797+ decoder , start_seconds = stop_seconds , stop_seconds = 6
798+ )
799+ torch .testing .assert_close (frames , get_reference_frames (stop_seconds , 6 ))
803800
804- with pytest . raises ( RuntimeError , match = expected_match ):
805- get_frames_by_pts_in_range_audio (
806- decoder , start_seconds = stop_seconds + 1e-4 , stop_seconds = 6
807- )
801+ get_frames_by_pts_in_range_audio (
802+ decoder , start_seconds = stop_seconds + 1e-4 , stop_seconds = 6
803+ )
804+ torch . testing . assert_close ( frames , get_reference_frames ( stop_seconds , 6 ) )
808805
809- # and seeking backwards doesn't work either
810- with pytest . raises ( RuntimeError , match = expected_match ):
811- frames = get_frames_by_pts_in_range_audio (
812- decoder , start_seconds = 0 , stop_seconds = 2
813- )
806+ # seeking backwards
807+ frames = get_frames_by_pts_in_range_audio (
808+ decoder , start_seconds = 0 , stop_seconds = 2
809+ )
810+ torch . testing . assert_close ( frames , get_reference_frames ( 0 , 2 ) )
814811
815812
816813if __name__ == "__main__" :
0 commit comments