Skip to content

Commit a661328

Browse files
committed
Add set_cuda_backend Context Manager to publicly expose the BETA CUDA interface
1 parent 61202b9 commit a661328

File tree

4 files changed

+66
-13
lines changed

4 files changed

+66
-13
lines changed

src/torchcodec/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# but that results in circular import.
99
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
1010
from . import decoders, samplers # noqa
11+
from .decoders._decoder_utils import set_cuda_backend # noqa
1112

1213
try:
1314
# Note that version.py is generated during install.

src/torchcodec/decoders/_decoder_utils.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextvars
78
import io
9+
from contextlib import contextmanager
810
from pathlib import Path
911

10-
from typing import Union
12+
from typing import Generator, Optional, Union
1113

1214
from torch import Tensor
1315
from torchcodec import _core as core
@@ -50,3 +52,52 @@ def create_decoder(
5052
"read(self, size: int) -> bytes and "
5153
"seek(self, offset: int, whence: int) -> int methods."
5254
)
55+
56+
57+
# Thread-local and async-safe storage for the current CUDA backend
58+
_CUDA_BACKEND: contextvars.ContextVar[str] = contextvars.ContextVar(
59+
"_CUDA_BACKEND", default="ffmpeg"
60+
)
61+
62+
63+
@contextmanager
64+
def set_cuda_backend(backend: str) -> Generator[None, None, None]:
65+
"""Context Manager to set the CUDA backend for :class:`~torchcodec.decoders.VideoDecoder`.
66+
67+
This context manager allows you to specify which CUDA backend implementation
68+
to use when creating :class:`~torchcodec.decoders.VideoDecoder` instances
69+
with CUDA devices. This is thread-safe and async-safe.
70+
71+
Note that you still need to pass `device="cuda"` when creating the
72+
:class:`~torchcodec.decoders.VideoDecoder` instance. If a CUDA device isn't
73+
specified, this context manager will have no effect.
74+
75+
Only the creation of the decoder needs to be inside the context manager, the
76+
decoding methods can be called outside of it.
77+
78+
Args:
79+
backend (str): The CUDA backend to use. Can be "ffmpeg" or "beta". Default is "ffmpeg".
80+
81+
Example:
82+
>>> with torchcodec.set_cuda_backend("beta"):
83+
... decoder = VideoDecoder("video.mp4", device="cuda")
84+
...
85+
... # Only the decoder creation needs to be part of the context manager.
86+
... # Decoder will now the beta CUDA implementation:
87+
... decoder.get_frame_at(0)
88+
"""
89+
backend = backend.lower()
90+
if backend not in ("ffmpeg", "beta"):
91+
raise ValueError(
92+
f"Invalid CUDA backend ({backend}). Supported values are 'ffmpeg' and 'beta'."
93+
)
94+
95+
previous_state = _CUDA_BACKEND.set(backend)
96+
try:
97+
yield
98+
finally:
99+
_CUDA_BACKEND.reset(previous_state)
100+
101+
102+
def _get_current_cuda_backend() -> str:
103+
return _CUDA_BACKEND.get()

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from torchcodec import _core as core, Frame, FrameBatch
1717
from torchcodec.decoders._decoder_utils import (
18+
_get_current_cuda_backend,
1819
create_decoder,
1920
ERROR_REPORTING_INSTRUCTIONS,
2021
)
@@ -143,17 +144,17 @@ def __init__(
143144
if isinstance(device, torch_device):
144145
device = str(device)
145146

146-
# If device looks like "cuda:0:beta", make it "cuda:0" and set
147-
# device_variant to "beta"
148-
# TODONVDEC P2 Consider alternative ways of exposing custom device
149-
# variants, and if we want this new decoder backend to be a "device
150-
# variant" at all.
151-
device_variant = "default"
152-
if device is not None:
153-
device_split = device.split(":")
154-
if len(device_split) == 3:
155-
device_variant = device_split[2]
156-
device = ":".join(device_split[0:2])
147+
device_variant = _get_current_cuda_backend()
148+
if device_variant == "ffmpeg":
149+
# TODONVDEC P2 rename 'default' into 'ffmpeg' everywhere.
150+
device_variant = "default"
151+
152+
# Legacy support for device="cuda:0:beta" syntax
153+
# TODONVDEC P2: remove support for this everywhere. This will require
154+
# updating our tests.
155+
if device == "cuda:0:beta":
156+
device = "cuda:0"
157+
device_variant = "beta"
157158

158159
core.add_video_stream(
159160
self._decoder,

test/test_decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1702,7 +1702,7 @@ def test_beta_cuda_interface_small_h265(self):
17021702

17031703
@needs_cuda
17041704
def test_beta_cuda_interface_error(self):
1705-
with pytest.raises(RuntimeError, match="Unsupported device"):
1705+
with pytest.raises(RuntimeError, match="Invalid device string"):
17061706
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
17071707

17081708

0 commit comments

Comments
 (0)