1414import torch
1515
1616from torchcodec ._core import get_ffmpeg_library_versions
17+ from torchcodec .decoders import set_cuda_backend , VideoDecoder
1718from torchcodec .decoders ._video_decoder import _read_custom_frame_mappings
1819
1920IS_WINDOWS = sys .platform in ("win32" , "cygwin" )
@@ -26,27 +27,55 @@ def needs_cuda(test_item):
2627 return pytest .mark .needs_cuda (test_item )
2728
2829
30+ # This is a special device string that we use to test the "beta" CUDA backend.
31+ # It only exists here, in this test utils file. Public and core APIs have no
32+ # idea that this is how we're tesing them. That is, that's not a supported
33+ # `device` parameter for the VideoDecoder or for the _core APIs.
34+ # Tests using all_supported_devices() will get this device string, and the test
35+ # need to clean it up by calling either make_video_decoder for VideoDecoder, or
36+ # unsplit_device_str for core APIs.
37+ _CUDA_BETA_DEVICE_STR = "cuda:beta"
38+
39+
2940def all_supported_devices ():
3041 return (
3142 "cpu" ,
3243 pytest .param ("cuda" , marks = pytest .mark .needs_cuda ),
33- pytest .param ("cuda:0:beta" , marks = pytest .mark .needs_cuda ),
44+ pytest .param (_CUDA_BETA_DEVICE_STR , marks = pytest .mark .needs_cuda ),
3445 )
3546
3647
3748def unsplit_device_str (device_str : str ) -> str :
3849 # helper meant to be used as
3950 # device, device_variant = unsplit_device_str(device)
40- # when `device` comes from all_supported_devices() and may be "cuda:0:beta" .
51+ # when `device` comes from all_supported_devices() and may be _CUDA_BETA_DEVICE_STR .
4152 # It is used:
42- # - before calling `.to(device)` where device can't be "cuda:0:beta"
53+ # - before calling `.to(device)` where device can't be _CUDA_BETA_DEVICE_STR.
4354 # - before calling add_video_stream(device=device, device_variant=device_variant)
44- if device_str == "cuda:0:beta" :
55+ if device_str == _CUDA_BETA_DEVICE_STR :
4556 return "cuda" , "beta"
4657 else :
4758 return device_str , "ffmpeg"
4859
4960
61+ def make_video_decoder (* args , ** kwargs ) -> tuple [VideoDecoder , str ]:
62+ # Helper to create a VideoDecoder with the right cuda backend if needed.
63+ # kwargs is expected to have a "device" key which comes from
64+ # all_supported_devices(), and can be _CUDA_BETA_DEVICE_STR.
65+ device = kwargs .pop ("device" , "cpu" )
66+ if device == _CUDA_BETA_DEVICE_STR :
67+ clean_device , backend = "cuda" , "beta"
68+ else :
69+ clean_device , backend = device , "ffmpeg"
70+
71+ # set_cuda_backend is a no-op if the device is "cpu", so we can use it
72+ # unconditionally.
73+ with set_cuda_backend (backend ):
74+ dec = VideoDecoder (* args , ** kwargs , device = clean_device )
75+
76+ return dec , clean_device
77+
78+
5079def get_ffmpeg_major_version ():
5180 ffmpeg_version = get_ffmpeg_library_versions ()["ffmpeg_version" ]
5281 # When building FFmpeg from source there can be a `n` prefix in the version
0 commit comments