3030 get_frames_at_indices ,
3131 get_frames_by_pts ,
3232 get_frames_by_pts_in_range ,
33+ get_frames_by_pts_in_range_audio ,
3334 get_frames_in_range ,
3435 get_json_metadata ,
3536 get_next_frame ,
@@ -638,20 +639,44 @@ def test_audio_bad_seek_mode(self):
638639 ):
639640 add_audio_stream (decoder )
640641
641- def test_audio_decode_all_samples_with_get_frames_by_pts_in_range (self ):
642- decoder = create_from_file (str (NASA_AUDIO .path ), seek_mode = "approximate" )
642+ # TODO-audio: this fails with NASA_AUDIO_MP3 because numFrame isn't in the
643+ # metadata
644+ # @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
645+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO ,))
646+ def test_audio_decode_all_samples_with_get_frames_by_pts_in_range (self , asset ):
647+ decoder = create_from_file (str (asset .path ), seek_mode = "approximate" )
643648 add_audio_stream (decoder )
644649
645650 reference_frames = [
646- NASA_AUDIO .get_frame_data_by_index (i ) for i in range (NASA_AUDIO .num_frames )
651+ asset .get_frame_data_by_index (i ) for i in range (asset .num_frames )
647652 ]
648- reference_frames = torch .stack (
649- reference_frames
650- ) # shape is (num_frames, C, num_samples_per_frame)
653+ # shape is (C, num_frames * num_samples_per_frame) while preserving frame order and boundaries
654+ reference_frames = torch .cat (reference_frames , dim = - 1 )
651655
652656 all_frames , * _ = get_frames_by_pts_in_range (
653- decoder , start_seconds = 0 , stop_seconds = NASA_AUDIO .duration_seconds
657+ decoder , start_seconds = 0 , stop_seconds = asset .duration_seconds
658+ )
659+ all_frames = torch .cat (all_frames .unbind (0 ), dim = - 1 )
660+
661+ assert_frames_equal (all_frames , reference_frames )
662+
663+ @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
664+ def test_audio_decode_all_samples_with_get_frames_by_pts_in_range_audio (
665+ self , asset
666+ ):
667+ decoder = create_from_file (str (asset .path ), seek_mode = "approximate" )
668+ add_audio_stream (decoder )
669+
670+ reference_frames = [
671+ asset .get_frame_data_by_index (i ) for i in range (asset .num_frames )
672+ ]
673+ # shape is (C, num_frames * num_samples_per_frame) while preserving frame order and boundaries
674+ reference_frames = torch .cat (reference_frames , dim = - 1 )
675+
676+ all_frames = get_frames_by_pts_in_range_audio (
677+ decoder , start_seconds = 0 , stop_seconds = asset .duration_seconds
654678 )
679+
655680 assert_frames_equal (all_frames , reference_frames )
656681
657682 @pytest .mark .parametrize ("asset" , (NASA_AUDIO , NASA_AUDIO_MP3 ))
@@ -663,7 +688,6 @@ def test_audio_decode_all_samples_with_next(self, asset):
663688 asset .get_frame_data_by_index (i ) for i in range (asset .num_frames )
664689 ]
665690
666- # shape is (C, num_frames * num_samples_per_frame) while preserving frame order and boundaries
667691 reference_frames = torch .cat (reference_frames , dim = - 1 )
668692
669693 all_frames = []
@@ -673,7 +697,7 @@ def test_audio_decode_all_samples_with_next(self, asset):
673697 all_frames .append (frame )
674698 except IndexError :
675699 break
676- all_frames = torch .cat (all_frames , axis = - 1 )
700+ all_frames = torch .cat (all_frames , dim = - 1 )
677701
678702 assert_frames_equal (all_frames , reference_frames )
679703
@@ -696,8 +720,8 @@ def test_audio_get_frames_by_pts_in_range(self, start_seconds, stop_seconds):
696720 add_audio_stream (decoder )
697721
698722 reference_frames = NASA_AUDIO .get_frame_data_by_range (
699- start = NASA_AUDIO .pts_to_frame_index ( start_seconds ),
700- stop = NASA_AUDIO .pts_to_frame_index ( stop_seconds ) + 1 ,
723+ start = NASA_AUDIO .get_frame_index ( pts_seconds = start_seconds ),
724+ stop = NASA_AUDIO .get_frame_index ( pts_seconds = stop_seconds ) + 1 ,
701725 )
702726 frames , _ , _ = get_frames_by_pts_in_range (
703727 decoder , start_seconds = start_seconds , stop_seconds = stop_seconds
@@ -722,7 +746,7 @@ def test_audio_seek_and_next(self):
722746 pts = 2
723747 # Need +1 because we're not at frames boundaries
724748 reference_frame = NASA_AUDIO .get_frame_data_by_index (
725- NASA_AUDIO .pts_to_frame_index ( pts ) + 1
749+ NASA_AUDIO .get_frame_index ( pts_seconds = pts ) + 1
726750 )
727751 seek_to_pts (decoder , pts )
728752 frame , _ , _ = get_next_frame (decoder )
@@ -731,7 +755,7 @@ def test_audio_seek_and_next(self):
731755 # Seeking forward is OK
732756 pts = 4
733757 reference_frame = NASA_AUDIO .get_frame_data_by_index (
734- NASA_AUDIO .pts_to_frame_index ( pts ) + 1
758+ NASA_AUDIO .get_frame_index ( pts_seconds = pts ) + 1
735759 )
736760 seek_to_pts (decoder , pts )
737761 frame , _ , _ = get_next_frame (decoder )
@@ -747,7 +771,7 @@ def test_audio_seek_and_next(self):
747771 # the "next" one without seeking. This assertion exists to illustrate
748772 # what currently happens, but it's obviously *wrong*.
749773 reference_frame = NASA_AUDIO .get_frame_data_by_index (
750- NASA_AUDIO .pts_to_frame_index ( prev_pts ) + 2
774+ NASA_AUDIO .get_frame_index ( pts_seconds = prev_pts ) + 2
751775 )
752776 assert_frames_equal (frame , reference_frame )
753777
0 commit comments