@@ -695,7 +695,7 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
695695 decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
696696 )
697697
698- assert_frames_equal (frames , reference_frames )
698+ torch . testing . assert_close (frames , reference_frames )
699699
700700 @pytest .mark .parametrize (
701701 "asset, expected_shape" , ((NASA_AUDIO , (2 , 1024 )), (NASA_AUDIO_MP3 , (2 , 576 )))
@@ -723,6 +723,46 @@ def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
723723 )
724724 assert frames .shape == expected_shape
725725
726+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
727+ def test_multiple_calls (self , asset ):
728+
729+ def decode_stateless (start_seconds , stop_seconds ):
730+ decoder = create_from_file (str (asset .path ), seek_mode = "approximate" )
731+ add_audio_stream (decoder )
732+
733+ return get_frames_by_pts_in_range_audio (
734+ decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
735+ )
736+
737+ decoder = create_from_file (str (asset .path ), seek_mode = "approximate" )
738+ add_audio_stream (decoder )
739+
740+ start_seconds , stop_seconds = 0 , 2
741+ frames = get_frames_by_pts_in_range_audio (
742+ decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
743+ )
744+ torch .testing .assert_close (
745+ frames , decode_stateless (start_seconds , stop_seconds )
746+ )
747+
748+ start_seconds , stop_seconds = 3 , 4
749+ frames = get_frames_by_pts_in_range_audio (
750+ decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
751+ )
752+ torch .testing .assert_close (
753+ frames , decode_stateless (start_seconds , stop_seconds )
754+ )
755+
756+ # TODO-AUDIO
757+ start_seconds , stop_seconds = 0 , 2
758+ frames = get_frames_by_pts_in_range_audio (
759+ decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
760+ )
761+ with pytest .raises (AssertionError ):
762+ torch .testing .assert_close (
763+ frames , decode_stateless (start_seconds , stop_seconds )
764+ )
765+
726766
727767if __name__ == "__main__" :
728768 pytest .main ()
0 commit comments