@@ -1709,17 +1709,23 @@ def test_beta_cuda_interface_cpu_fallback(self):
17091709 # fallbacks to the CPU path in such cases. We assert that we fall back
17101710 # to the CPU path, too.
17111711
1712- ffmpeg = VideoDecoder (H265_VIDEO .path , device = "cuda" ).get_frame_at (0 )
1713- with set_cuda_backend ("beta" ):
1714- beta = VideoDecoder (H265_VIDEO .path , device = "cuda" ).get_frame_at (0 )
1712+ ref_dec = VideoDecoder (H265_VIDEO .path , device = "cuda" )
1713+ ref_frames = ref_dec .get_frame_at (0 )
1714+ assert (
1715+ _core ._get_backend_details (ref_dec ._decoder )
1716+ == "FFmpeg CUDA Device Interface. Using CPU fallback."
1717+ )
17151718
1716- from torchvision . io import write_png
1717- from torchvision . utils import make_grid
1719+ with set_cuda_backend ( "beta" ):
1720+ beta_dec = VideoDecoder ( H265_VIDEO . path , device = "cuda" )
17181721
1719- write_png (make_grid ([ffmpeg .data , beta .data ], nrow = 2 ).cpu (), "out.png" )
1722+ assert (
1723+ _core ._get_backend_details (beta_dec ._decoder )
1724+ == "Beta CUDA Device Interface. Using CPU fallback."
1725+ )
1726+ beta_frame = beta_dec .get_frame_at (0 )
17201727
1721- assert psnr (ffmpeg .data .cpu (), beta .data .cpu ()) > 25
1722- # torch.testing.assert_close(ffmpeg.data, beta.data, rtol=0, atol=0)
1728+ assert psnr (ref_frames .data .cpu (), beta_frame .data .cpu ()) > 25
17231729
17241730 @needs_cuda
17251731 def test_beta_cuda_interface_error (self ):
0 commit comments