Skip to content

Commit 3afc97f

Browse files
committed
Add test
1 parent 340974a commit 3afc97f

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

test/test_decoders.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,17 +1709,23 @@ def test_beta_cuda_interface_cpu_fallback(self):
17091709
# fallbacks to the CPU path in such cases. We assert that we fall back
17101710
# to the CPU path, too.
17111711

1712-
ffmpeg = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
1713-
with set_cuda_backend("beta"):
1714-
beta = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
1712+
ref_dec = VideoDecoder(H265_VIDEO.path, device="cuda")
1713+
ref_frames = ref_dec.get_frame_at(0)
1714+
assert (
1715+
_core._get_backend_details(ref_dec._decoder)
1716+
== "FFmpeg CUDA Device Interface. Using CPU fallback."
1717+
)
17151718

1716-
from torchvision.io import write_png
1717-
from torchvision.utils import make_grid
1719+
with set_cuda_backend("beta"):
1720+
beta_dec = VideoDecoder(H265_VIDEO.path, device="cuda")
17181721

1719-
write_png(make_grid([ffmpeg.data, beta.data], nrow=2).cpu(), "out.png")
1722+
assert (
1723+
_core._get_backend_details(beta_dec._decoder)
1724+
== "Beta CUDA Device Interface. Using CPU fallback."
1725+
)
1726+
beta_frame = beta_dec.get_frame_at(0)
17201727

1721-
assert psnr(ffmpeg.data.cpu(), beta.data.cpu()) > 25
1722-
# torch.testing.assert_close(ffmpeg.data, beta.data, rtol=0, atol=0)
1728+
assert psnr(ref_frames.data.cpu(), beta_frame.data.cpu()) > 25
17231729

17241730
@needs_cuda
17251731
def test_beta_cuda_interface_error(self):

0 commit comments

Comments
 (0)