@@ -195,6 +195,29 @@ def test_get_frame_at_pts_audio(self, seek_mode):
195195 with pytest .raises (AssertionError ):
196196 assert_frames_equal (next_frame , reference_frame6 )
197197
def test_get_frame_at_pts_audio_bad(self):
    """Pin the current (incorrect) audio decoding at a pts that is mid-frame.

    Asking twice for the frame played at 6.05 on the same decoder gives a
    wrong result the first time and the expected reference frame the second
    time. Both outcomes are asserted so that any change in this behavior is
    noticed.
    """
    audio_path = str(NASA_AUDIO.path)
    decoder = create_from_file(audio_path)
    add_audio_stream(decoder=decoder)

    expected_frame = NASA_AUDIO.get_frame_data_by_index(
        INDEX_OF_AUDIO_FRAME_AFTER_SEEKING_AT_6
    )
    mid_frame_pts = 6.05

    # See Note [Seek offset for audio].
    # The frame played at 6.05 should be the reference frame, but 6.05 is
    # not exactly the start of that frame, so the samples come out
    # incorrectly decoded on this first call.
    # TODO Fix this.
    first_attempt, _pts, _duration = get_frame_at_pts(decoder, mid_frame_pts)
    with pytest.raises(AssertionError):
        assert_frames_equal(first_attempt, expected_frame)

    # Another quirk: asking for the very same pts a second time yields
    # correctly decoded samples. getFramePlayedAt() has custom logic that,
    # in some very specific cases, resets desiredPts to the pts of the
    # beginning of the frame.
    second_attempt, _pts, _duration = get_frame_at_pts(decoder, mid_frame_pts)
    assert_frames_equal(second_attempt, expected_frame)
220+
198221 @pytest .mark .parametrize ("test_ref" , (NASA_VIDEO , NASA_AUDIO ))
199222 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
200223 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
@@ -779,34 +802,48 @@ def test_cuda_decoder(self):
779802 )
780803
781804 def test_get_same_frame_twice (self ):
805+ # Non-regression tests that were useful while developing audio support.
782806 def make_decoder ():
783807 decoder = create_from_file (str (NASA_AUDIO .path ))
784808 add_audio_stream (decoder )
785809 return decoder
786810
787811 for frame_index in (0 , 10 , 15 ):
812+ ref = NASA_AUDIO .get_frame_data_by_index (frame_index )
813+
788814 decoder = make_decoder ()
789815 a = get_frame_at_index (decoder , frame_index = frame_index )
790816 b = get_frame_at_index (decoder , frame_index = frame_index )
791817 torch .testing .assert_close (a , b )
818+ torch .testing .assert_close (a [0 ], ref )
792819
793820 decoder = make_decoder ()
794821 a = get_frames_at_indices (decoder , frame_indices = [frame_index ])
795822 b = get_frames_at_indices (decoder , frame_indices = [frame_index ])
796823 torch .testing .assert_close (a , b )
824+ torch .testing .assert_close (a [0 ][0 ], ref )
797825
798826 decoder = make_decoder ()
799827 a = get_frames_in_range (decoder , start = frame_index , stop = frame_index + 1 )
800828 b = get_frames_in_range (decoder , start = frame_index , stop = frame_index + 1 )
801829 torch .testing .assert_close (a , b )
830+ torch .testing .assert_close (a [0 ][0 ], ref )
802831
803- pts_at_frame_start = 0
832+ pts_at_frame_start = 0 # 0 corresponds exactly to a frame start
833+ index_of_frame_at_0 = 0
804834 pts_not_at_frame_start = 2 # second 2 is in the middle of a frame
805- for pts in (pts_at_frame_start , pts_not_at_frame_start ):
835+ index_of_frame_at_2 = 31
836+ for pts , frame_index in (
837+ (pts_at_frame_start , index_of_frame_at_0 ),
838+ (pts_not_at_frame_start , index_of_frame_at_2 ),
839+ ):
840+ ref = NASA_AUDIO .get_frame_data_by_index (frame_index )
841+
806842 decoder = make_decoder ()
807843 a = get_frames_by_pts (decoder , timestamps = [pts ])
808844 b = get_frames_by_pts (decoder , timestamps = [pts ])
809845 torch .testing .assert_close (a , b )
846+ torch .testing .assert_close (a [0 ][0 ], ref )
810847
811848 decoder = make_decoder ()
812849 a = get_frames_by_pts_in_range (
@@ -816,11 +853,15 @@ def make_decoder():
816853 decoder , start_seconds = pts , stop_seconds = pts + 1e-4
817854 )
818855 torch .testing .assert_close (a , b )
856+ torch .testing .assert_close (a [0 ][0 ], ref )
819857
820858 decoder = make_decoder ()
821859 a = get_frame_at_pts (decoder , seconds = pts_at_frame_start )
822860 b = get_frame_at_pts (decoder , seconds = pts_at_frame_start )
823861 torch .testing .assert_close (a , b )
862+ torch .testing .assert_close (
863+ a [0 ], NASA_AUDIO .get_frame_data_by_index (index_of_frame_at_0 )
864+ )
824865
825866 decoder = make_decoder ()
826867 a_frame , a_pts , a_duration = get_frame_at_pts (
@@ -831,8 +872,17 @@ def make_decoder():
831872 )
832873 torch .testing .assert_close (a_pts , b_pts )
833874 torch .testing .assert_close (a_duration , b_duration )
875+ # TODO: fix this. The two checks wrapped in pytest.raises() below should pass.
834876 with pytest .raises (AssertionError ):
835877 torch .testing .assert_close (a_frame , b_frame )
878+ with pytest .raises (AssertionError ):
879+ torch .testing .assert_close (
880+ a_frame , NASA_AUDIO .get_frame_data_by_index (index_of_frame_at_2 )
881+ )
882+ # But the second call works ¯\_(ツ)_/¯ (see also test_get_frame_at_pts_audio_bad())
883+ torch .testing .assert_close (
884+ b_frame , NASA_AUDIO .get_frame_data_by_index (index_of_frame_at_2 )
885+ )
836886
837887 decoder = make_decoder ()
838888 seek_to_pts (decoder , pts_at_frame_start )
@@ -841,13 +891,15 @@ def make_decoder():
841891 b = get_next_frame (decoder )
842892 torch .testing .assert_close (a , b )
843893
844- # TODO: Wait WTFFF, this should not pass
845894 decoder = make_decoder ()
846895 seek_to_pts (decoder , seconds = pts_not_at_frame_start )
847896 a = get_next_frame (decoder )
848897 seek_to_pts (decoder , seconds = pts_not_at_frame_start )
849898 b = get_next_frame (decoder )
850899 torch .testing .assert_close (a , b )
900+ torch .testing .assert_close (
901+ a [0 ], NASA_AUDIO .get_frame_data_by_index (index_of_frame_at_2 + 1 )
902+ )
851903
852904
853905if __name__ == "__main__" :
0 commit comments