Skip to content

Commit 5a8e50b

Browse files
committed
Add tests
1 parent 5d0eeaf commit 5a8e50b

File tree

5 files changed

+60
-4
lines changed

5 files changed

+60
-4
lines changed

src/torchcodec/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
# but that results in circular import.
99
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
1010
from . import decoders, samplers # noqa
11-
from .decoders._decoder_utils import set_cuda_backend # noqa
1211

1312
try:
1413
# Note that version.py is generated during install.

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .._core import AudioStreamMetadata, VideoStreamMetadata
88
from ._audio_decoder import AudioDecoder # noqa
9+
from ._decoder_utils import set_cuda_backend # noqa
910
from ._video_decoder import VideoDecoder # noqa
1011

1112
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_decoder_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,5 +99,5 @@ def set_cuda_backend(backend: str) -> Generator[None, None, None]:
9999
_CUDA_BACKEND.reset(previous_state)
100100

101101

102-
def _get_current_cuda_backend() -> str:
102+
def _get_cuda_backend() -> str:
103103
return _CUDA_BACKEND.get()

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from torchcodec import _core as core, Frame, FrameBatch
1717
from torchcodec.decoders._decoder_utils import (
18-
_get_current_cuda_backend,
18+
_get_cuda_backend,
1919
create_decoder,
2020
ERROR_REPORTING_INSTRUCTIONS,
2121
)
@@ -144,7 +144,7 @@ def __init__(
144144
if isinstance(device, torch_device):
145145
device = str(device)
146146

147-
device_variant = _get_current_cuda_backend()
147+
device_variant = _get_cuda_backend()
148148
if device_variant == "ffmpeg":
149149
# TODONVDEC P2 rename 'default' into 'ffmpeg' everywhere.
150150
device_variant = "default"

test/test_decoders.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from torchcodec.decoders import (
1919
AudioDecoder,
2020
AudioStreamMetadata,
21+
set_cuda_backend,
2122
VideoDecoder,
2223
VideoStreamMetadata,
2324
)
25+
from torchcodec.decoders._decoder_utils import _get_cuda_backend
2426

2527
from .utils import (
2628
all_supported_devices,
@@ -1705,6 +1707,60 @@ def test_beta_cuda_interface_error(self):
17051707
with pytest.raises(RuntimeError, match="Invalid device string"):
17061708
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
17071709

1710+
@needs_cuda
1711+
def test_set_cuda_backend(self):
1712+
# Tests for the set_cuda_backend() context manager.
1713+
1714+
with pytest.raises(ValueError, match="Invalid CUDA backend"):
1715+
with set_cuda_backend("bad_backend"):
1716+
pass
1717+
1718+
# set_cuda_backend() is meant to be used as a context manager. Using it
1719+
# as a global call does nothing because the "context" is exited right
1720+
# away. This is a good thing, we prefer users to use it as a CM only.
1721+
set_cuda_backend("beta")
1722+
assert _get_cuda_backend() == "ffmpeg" # Not changed to "beta".
1723+
1724+
# Case insensitive
1725+
with set_cuda_backend("BETA"):
1726+
assert _get_cuda_backend() == "beta"
1727+
1728+
def assert_decoder_uses(decoder, *, expected_backend):
1729+
# Assert that a decoder instance is using a given backend.
1730+
#
1731+
# We know H265_VIDEO fails on the BETA backend while it works on the
1732+
# ffmpeg one.
1733+
if expected_backend == "ffmpeg":
1734+
decoder.get_frame_at(0) # this would fail if this was BETA
1735+
else:
1736+
with pytest.raises(RuntimeError, match="Video is too small"):
1737+
decoder.get_frame_at(0)
1738+
1739+
# Check that the default is the ffmpeg backend
1740+
assert _get_cuda_backend() == "ffmpeg"
1741+
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
1742+
assert_decoder_uses(dec, expected_backend="ffmpeg")
1743+
1744+
# Check the setting "beta" effectively uses the BETA backend.
1745+
# We also show that the affects decoder creation only. When the decoder
1746+
# is created with a given backend, it stays in this backend for the rest
1747+
# of its life. This is normal and intended.
1748+
with set_cuda_backend("beta"):
1749+
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
1750+
assert _get_cuda_backend() == "ffmpeg"
1751+
assert_decoder_uses(dec, expected_backend="beta")
1752+
with set_cuda_backend("ffmpeg"):
1753+
assert_decoder_uses(dec, expected_backend="beta")
1754+
1755+
# Hacky way to ensure passing "cuda:1" is supported by both backends. We
1756+
# just check that there's an error when passing cuda:N where N is too
1757+
# high.
1758+
bad_device_number = torch.cuda.device_count() + 1
1759+
for backend in ("ffmpeg", "beta"):
1760+
with pytest.raises(RuntimeError, match="invalid device ordinal"):
1761+
with set_cuda_backend(backend):
1762+
VideoDecoder(H265_VIDEO.path, device=f"cuda:{bad_device_number}")
1763+
17081764

17091765
class TestAudioDecoder:
17101766
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)