@@ -747,19 +747,18 @@ def test_decode_start_equal_stop(self, asset):
747747
748748 @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
749749 def test_multiple_calls (self , asset ):
750- # Ensure that multiple calls are OK as long as we're decoding
751- # "sequentially", i.e. we don't require a backwards seek.
752- # And ensure a proper error is raised in such case.
753- # TODO-AUDIO We shouldn't error, we should just implement the seeking
754- # back to the beginning of the stream.
750+ # Ensure that multiple calls to get_frames_by_pts_in_range_audio on the
751+ # same decoder are supported and correct, whether it involves forward
752+ # seeks or backwards seeks.
755753
756754 def get_reference_frames (start_seconds , stop_seconds ):
757- # This stateless helper exists for convenience, to avoid
758- # complicating this test with pts-to-index conversions. Eventually
759- # we should remove it and just rely on the asset's methods.
760- # Using this helper is OK for now: we're comparing a decoder which
761- # seeks multiple times with a decoder which seeks only once (the one
762- # here, treated as the reference)
755+ # Usually we get the reference frames from the asset's methods, but
756+ # for this specific test, this helper is more convenient, because
757+ # relying on the asset would force us to convert all timestamps into
758+ # indices.
759+ # Ultimately, this test compares a "stateful decoder" which calls
760+ # `get_frames_by_pts_in_range_audio()`` multiple times with a
761+ # "stateless decoder" (the one here, treated as the reference)
763762 decoder = create_from_file (str (asset .path ), seek_mode = "approximate" )
764763 add_audio_stream (decoder )
765764
@@ -800,23 +799,30 @@ def get_reference_frames(start_seconds, stop_seconds):
800799 frames , get_reference_frames (start_seconds , stop_seconds )
801800 )
802801
803- # but starting immediately on the same frame raises
804- expected_match = "Audio decoder cannot seek backwards"
805- with pytest .raises (RuntimeError , match = expected_match ):
806- get_frames_by_pts_in_range_audio (
807- decoder , start_seconds = stop_seconds , stop_seconds = 6
808- )
802+ # starting immediately on the same frame is OK
803+ start_seconds , stop_seconds = stop_seconds , 6
804+ frames = get_frames_by_pts_in_range_audio (
805+ decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
806+ )
807+ torch .testing .assert_close (
808+ frames , get_reference_frames (start_seconds , stop_seconds )
809+ )
809810
810- with pytest .raises (RuntimeError , match = expected_match ):
811- get_frames_by_pts_in_range_audio (
812- decoder , start_seconds = stop_seconds + 1e-4 , stop_seconds = 6
813- )
811+ get_frames_by_pts_in_range_audio (
812+ decoder , start_seconds = start_seconds + 1e-4 , stop_seconds = stop_seconds
813+ )
814+ torch .testing .assert_close (
815+ frames , get_reference_frames (start_seconds , stop_seconds )
816+ )
814817
815- # and seeking backwards doesn't work either
816- with pytest .raises (RuntimeError , match = expected_match ):
817- frames = get_frames_by_pts_in_range_audio (
818- decoder , start_seconds = 0 , stop_seconds = 2
819- )
818+ # seeking backwards
819+ start_seconds , stop_seconds = 0 , 2
820+ frames = get_frames_by_pts_in_range_audio (
821+ decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
822+ )
823+ torch .testing .assert_close (
824+ frames , get_reference_frames (start_seconds , stop_seconds )
825+ )
820826
821827
822828if __name__ == "__main__" :
0 commit comments