@@ -1378,9 +1378,9 @@ def test_bad_input(self, tmp_path):
13781378 filename = "./bad/path.mp3" ,
13791379 )
13801380
1381- def decode (self , file_path ) -> torch .Tensor :
1381+ def decode (self , file_path , device = "cpu" ) -> torch .Tensor :
13821382 decoder = create_from_file (str (file_path ), seek_mode = "approximate" )
1383- add_video_stream (decoder )
1383+ add_video_stream (decoder , device = device )
13841384 frames , * _ = get_frames_in_range (decoder , start = 0 , stop = 60 )
13851385 return frames
13861386
@@ -1399,14 +1399,14 @@ def test_video_encoder_round_trip(self, tmp_path, format):
13991399 ):
14001400 pytest .skip ("Codec for webm is not available in this FFmpeg installation." )
14011401 asset = TEST_SRC_2_720P
1402- source_frames = self .decode (str (asset .path )).data
1402+ source_frames = self .decode (str (asset .path ), device = "cpu" ).data
14031403
14041404 encoded_path = str (tmp_path / f"encoder_output.{ format } " )
14051405 frame_rate = 30 # Frame rate is fixed with num frames decoded
14061406 encode_video_to_file (
14071407 frames = source_frames , frame_rate = frame_rate , filename = encoded_path , crf = 0
14081408 )
1409- round_trip_frames = self .decode (encoded_path ).data
1409+ round_trip_frames = self .decode (encoded_path , device = "cpu" ).data
14101410 assert source_frames .shape == round_trip_frames .shape
14111411 assert source_frames .dtype == round_trip_frames .dtype
14121412
0 commit comments