2222)
2323
2424from .utils import (
25+ all_supported_devices ,
2526 assert_frames_equal ,
2627 AV1_VIDEO ,
27- cpu_and_cuda ,
2828 get_ffmpeg_major_version ,
2929 H264_10BITS ,
3030 H265_10BITS ,
@@ -163,7 +163,7 @@ def test_create_fails(self):
163163 VideoDecoder (NASA_VIDEO .path , seek_mode = "blah" )
164164
165165 @pytest .mark .parametrize ("num_ffmpeg_threads" , (1 , 4 ))
166- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
166+ @pytest .mark .parametrize ("device" , all_supported_devices ())
167167 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
168168 def test_getitem_int (self , num_ffmpeg_threads , device , seek_mode ):
169169 decoder = VideoDecoder (
@@ -213,7 +213,7 @@ def test_getitem_numpy_int(self):
213213 assert_frames_equal (ref_frame1 , decoder [numpy .uint32 (1 )])
214214 assert_frames_equal (ref_frame180 , decoder [numpy .uint32 (180 )])
215215
216- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
216+ @pytest .mark .parametrize ("device" , all_supported_devices ())
217217 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
218218 def test_getitem_slice (self , device , seek_mode ):
219219 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -373,7 +373,7 @@ def test_device_instance(self):
373373 decoder = VideoDecoder (NASA_VIDEO .path , device = torch .device ("cpu" ))
374374 assert isinstance (decoder .metadata , VideoStreamMetadata )
375375
376- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
376+ @pytest .mark .parametrize ("device" , all_supported_devices ())
377377 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
378378 def test_getitem_fails (self , device , seek_mode ):
379379 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -390,7 +390,7 @@ def test_getitem_fails(self, device, seek_mode):
390390 with pytest .raises (TypeError , match = "Unsupported key type" ):
391391 frame = decoder [2.3 ] # noqa
392392
393- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
393+ @pytest .mark .parametrize ("device" , all_supported_devices ())
394394 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
395395 def test_iteration (self , device , seek_mode ):
396396 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -437,7 +437,7 @@ def test_iteration_slow(self):
437437
438438 assert iterations == len (decoder ) == 390
439439
440- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
440+ @pytest .mark .parametrize ("device" , all_supported_devices ())
441441 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
442442 def test_get_frame_at (self , device , seek_mode ):
443443 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -475,7 +475,7 @@ def test_get_frame_at(self, device, seek_mode):
475475 frame9 = decoder .get_frame_at (numpy .uint32 (9 ))
476476 assert_frames_equal (ref_frame9 , frame9 .data )
477477
478- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
478+ @pytest .mark .parametrize ("device" , all_supported_devices ())
479479 def test_get_frame_at_tuple_unpacking (self , device ):
480480 decoder = VideoDecoder (NASA_VIDEO .path , device = device )
481481
@@ -486,7 +486,7 @@ def test_get_frame_at_tuple_unpacking(self, device):
486486 assert frame .pts_seconds == pts
487487 assert frame .duration_seconds == duration
488488
489- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
489+ @pytest .mark .parametrize ("device" , all_supported_devices ())
490490 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
491491 def test_get_frame_at_fails (self , device , seek_mode ):
492492 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -500,7 +500,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
500500 with pytest .raises (IndexError , match = "must be less than" ):
501501 frame = decoder .get_frame_at (10000 ) # noqa
502502
503- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
503+ @pytest .mark .parametrize ("device" , all_supported_devices ())
504504 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
505505 def test_get_frames_at (self , device , seek_mode ):
506506 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -551,7 +551,7 @@ def test_get_frames_at(self, device, seek_mode):
551551 frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
552552 )
553553
554- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
554+ @pytest .mark .parametrize ("device" , all_supported_devices ())
555555 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
556556 def test_get_frames_at_fails (self , device , seek_mode ):
557557 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -568,7 +568,7 @@ def test_get_frames_at_fails(self, device, seek_mode):
568568 with pytest .raises (RuntimeError , match = "Expected a value of type" ):
569569 decoder .get_frames_at ([0.3 ])
570570
571- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
571+ @pytest .mark .parametrize ("device" , all_supported_devices ())
572572 def test_get_frame_at_av1 (self , device ):
573573 if device == "cuda" and get_ffmpeg_major_version () == 4 :
574574 return
@@ -581,7 +581,7 @@ def test_get_frame_at_av1(self, device):
581581 assert decoded_frame10 .pts_seconds == ref_frame_info10 .pts_seconds
582582 assert_frames_equal (decoded_frame10 .data , ref_frame10 .to (device = device ))
583583
584- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
584+ @pytest .mark .parametrize ("device" , all_supported_devices ())
585585 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
586586 def test_get_frame_played_at (self , device , seek_mode ):
587587 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -610,7 +610,7 @@ def test_get_frame_played_at_h265(self):
610610 ref_frame6 = H265_VIDEO .get_frame_data_by_index (5 )
611611 assert_frames_equal (ref_frame6 , decoder .get_frame_played_at (0.5 ).data )
612612
613- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
613+ @pytest .mark .parametrize ("device" , all_supported_devices ())
614614 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
615615 def test_get_frame_played_at_fails (self , device , seek_mode ):
616616 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -621,7 +621,7 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
621621 with pytest .raises (IndexError , match = "Invalid pts in seconds" ):
622622 frame = decoder .get_frame_played_at (100.0 ) # noqa
623623
624- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
624+ @pytest .mark .parametrize ("device" , all_supported_devices ())
625625 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
626626 def test_get_frames_played_at (self , device , seek_mode ):
627627
@@ -660,7 +660,7 @@ def test_get_frames_played_at(self, device, seek_mode):
660660 frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
661661 )
662662
663- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
663+ @pytest .mark .parametrize ("device" , all_supported_devices ())
664664 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
665665 def test_get_frames_played_at_fails (self , device , seek_mode ):
666666 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -674,7 +674,7 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
674674 with pytest .raises (RuntimeError , match = "Expected a value of type" ):
675675 decoder .get_frames_played_at (["bad" ])
676676
677- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
677+ @pytest .mark .parametrize ("device" , all_supported_devices ())
678678 @pytest .mark .parametrize ("stream_index" , [0 , 3 , None ])
679679 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
680680 def test_get_frames_in_range (self , stream_index , device , seek_mode ):
@@ -779,7 +779,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
779779 empty_frames .duration_seconds , NASA_VIDEO .empty_duration_seconds
780780 )
781781
782- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
782+ @pytest .mark .parametrize ("device" , all_supported_devices ())
783783 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
784784 def test_get_frames_in_range_slice_indices_syntax (self , device , seek_mode ):
785785 decoder = VideoDecoder (
@@ -831,7 +831,7 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
831831 ).to (device )
832832 assert_frames_equal (frames387_None .data , reference_frame387_389 )
833833
834- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
834+ @pytest .mark .parametrize ("device" , all_supported_devices ())
835835 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
836836 @patch ("torchcodec._core._metadata._get_stream_json_metadata" )
837837 def test_get_frames_with_missing_num_frames_metadata (
@@ -894,7 +894,7 @@ def test_get_frames_with_missing_num_frames_metadata(
894894 lambda decoder : decoder .get_frames_played_in_range (0 , 1 ).data ,
895895 ),
896896 )
897- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
897+ @pytest .mark .parametrize ("device" , all_supported_devices ())
898898 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
899899 def test_dimension_order (self , dimension_order , frame_getter , device , seek_mode ):
900900 decoder = VideoDecoder (
@@ -922,7 +922,7 @@ def test_dimension_order_fails(self):
922922 VideoDecoder (NASA_VIDEO .path , dimension_order = "NCDHW" )
923923
924924 @pytest .mark .parametrize ("stream_index" , [0 , 3 , None ])
925- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
925+ @pytest .mark .parametrize ("device" , all_supported_devices ())
926926 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
927927 def test_get_frames_by_pts_in_range (self , stream_index , device , seek_mode ):
928928 decoder = VideoDecoder (
@@ -1061,7 +1061,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
10611061 )
10621062 assert_frames_equal (all_frames .data , decoder [:])
10631063
1064- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1064+ @pytest .mark .parametrize ("device" , all_supported_devices ())
10651065 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
10661066 def test_get_frames_by_pts_in_range_fails (self , device , seek_mode ):
10671067 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -1075,7 +1075,7 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
10751075 with pytest .raises (ValueError , match = "Invalid stop seconds" ):
10761076 frame = decoder .get_frames_played_in_range (0 , 23 ) # noqa
10771077
1078- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1078+ @pytest .mark .parametrize ("device" , all_supported_devices ())
10791079 def test_get_key_frame_indices (self , device ):
10801080 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = "exact" )
10811081 key_frame_indices = decoder ._get_key_frame_indices ()
@@ -1120,7 +1120,7 @@ def test_get_key_frame_indices(self, device):
11201120
11211121 # TODO investigate why this fails internally.
11221122 @pytest .mark .skipif (in_fbcode (), reason = "Compile test fails internally." )
1123- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1123+ @pytest .mark .parametrize ("device" , all_supported_devices ())
11241124 def test_compile (self , device ):
11251125 decoder = VideoDecoder (NASA_VIDEO .path , device = device )
11261126
0 commit comments