diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 0e20c5e8d..67c274136 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -241,6 +241,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( std::optional preAllocatedOutputTensor) { validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame); + hasDecodedFrame_ = true; + // All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by // converting them to NV12. avFrame = maybeConvertAVFrameToNV12OrRGB24(avFrame); @@ -358,6 +360,10 @@ std::string CudaDeviceInterface::getDetails() { // Note: for this interface specifically the fallback is only known after a // frame has been decoded, not before: that's when FFmpeg decides to fallback, // so we can't know earlier. + if (!hasDecodedFrame_) { + return std::string( + "FFmpeg CUDA Device Interface. Fallback status unknown (no frames decoded)."); + } return std::string("FFmpeg CUDA Device Interface. Using ") + (usingCPUFallback_ ? "CPU fallback." : "NVDEC."); } diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index c892bd49b..90d359185 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -63,6 +63,7 @@ class CudaDeviceInterface : public DeviceInterface { std::unique_ptr nv12Conversion_; bool usingCPUFallback_ = false; + bool hasDecodedFrame_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/__init__.py b/src/torchcodec/decoders/__init__.py index 980ba98a9..ef08cce83 100644 --- a/src/torchcodec/decoders/__init__.py +++ b/src/torchcodec/decoders/__init__.py @@ -7,6 +7,6 @@ from .._core import AudioStreamMetadata, VideoStreamMetadata from ._audio_decoder import AudioDecoder # noqa from ._decoder_utils import set_cuda_backend # noqa -from ._video_decoder import VideoDecoder # noqa +from ._video_decoder import CpuFallbackStatus, VideoDecoder # noqa SimpleVideoDecoder = VideoDecoder diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 9c2727bad..2f91878ca 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -7,6 +7,7 @@ import io import json import numbers +from dataclasses import dataclass from pathlib import Path from typing import Literal, Optional, Sequence, Tuple, Union @@ -23,6 +24,51 @@ from torchcodec.transforms._decoder_transforms import _make_transform_specs +@dataclass +class CpuFallbackStatus: + """Information about CPU fallback status. + + This class tracks whether the decoder fell back to CPU decoding. + + Usage: + - Use ``str(cpu_fallback_status)`` or ``print(cpu_fallback_status)`` to see the cpu fallback status + - Use ``bool(cpu_fallback_status)`` to check if any fallback occurred + + Attributes: + status_known (bool): Whether the fallback status has been determined. + """ + + def __init__(self): + self.status_known = False + self._nvcuvid_unavailable = False + self._video_not_supported = False + self._backend = "" + + def __bool__(self): + """Returns True if fallback occurred.""" + return self.status_known and ( + self._nvcuvid_unavailable or self._video_not_supported + ) + + def __str__(self): + """Returns a human-readable string representation of the cpu fallback status.""" + if not self.status_known: + return "Fallback status: Unknown" + + reasons = [] + if self._nvcuvid_unavailable: + reasons.append("NVcuvid unavailable") + if self._video_not_supported: + reasons.append("Video not supported") + + if reasons: + return ( + f"[{self._backend}] Fallback status: Falling back due to: " + + ", ".join(reasons) + ) + return f"[{self._backend}] Fallback status: No fallback required" + + class VideoDecoder: """A single-stream video decoder. @@ -101,6 +147,10 @@ class VideoDecoder: stream_index (int): The stream index that this decoder is retrieving frames from. If a stream index was provided at initialization, this is the same value. If it was left unspecified, this is the :term:`best stream`. + cpu_fallback (CpuFallbackStatus): Information about whether the decoder fell back to CPU + decoding. Use ``bool(cpu_fallback)`` to check if fallback occurred, or + ``str(cpu_fallback)`` to get a human-readable status message. The status is only + determined after at least one frame has been decoded. """ def __init__( @@ -184,9 +234,34 @@ def __init__( custom_frame_mappings=custom_frame_mappings_data, ) + self._cpu_fallback = CpuFallbackStatus() + def __len__(self) -> int: return self._num_frames + @property + def cpu_fallback(self) -> CpuFallbackStatus: + # We can only determine whether fallback to CPU is happening when this + # property is accessed and requires that at least one frame has been decoded. + if not self._cpu_fallback.status_known: + backend_details = core._get_backend_details(self._decoder) + + if "status unknown" not in backend_details: + self._cpu_fallback.status_known = True + + for backend in ("FFmpeg CUDA", "Beta CUDA", "CPU"): + if backend_details.startswith(backend): + self._cpu_fallback._backend = backend + break + + if "CPU fallback" in backend_details: + if "NVCUVID not available" in backend_details: + self._cpu_fallback._nvcuvid_unavailable = True + else: + self._cpu_fallback._video_not_supported = True + + return self._cpu_fallback + def _getitem_int(self, key: int) -> Tensor: assert isinstance(key, int) diff --git a/test/test_decoders.py b/test/test_decoders.py index efa2d11c8..b56d70290 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1672,22 +1672,27 @@ def test_beta_cuda_interface_cpu_fallback(self): # to the CPU path, too. ref_dec = VideoDecoder(H265_VIDEO.path, device="cuda") - ref_frames = ref_dec.get_frame_at(0) - assert ( - _core._get_backend_details(ref_dec._decoder) - == "FFmpeg CUDA Device Interface. Using CPU fallback." - ) + + # Before accessing any frames, status should be unknown + assert not ref_dec.cpu_fallback.status_known + + ref_frame = ref_dec.get_frame_at(0) + + assert "FFmpeg CUDA" in str(ref_dec.cpu_fallback) + assert ref_dec.cpu_fallback.status_known + assert bool(ref_dec.cpu_fallback) with set_cuda_backend("beta"): beta_dec = VideoDecoder(H265_VIDEO.path, device="cuda") - assert ( - _core._get_backend_details(beta_dec._decoder) - == "Beta CUDA Device Interface. Using CPU fallback." - ) + assert "Beta CUDA" in str(beta_dec.cpu_fallback) + # For beta interface, status is known immediately + assert beta_dec.cpu_fallback.status_known + assert bool(beta_dec.cpu_fallback) + beta_frame = beta_dec.get_frame_at(0) - assert psnr(ref_frames.data, beta_frame.data) > 25 + assert psnr(ref_frame.data, beta_frame.data) > 25 @needs_cuda def test_beta_cuda_interface_error(self): @@ -1715,7 +1720,8 @@ def test_set_cuda_backend(self): # Check that the default is the ffmpeg backend assert _get_cuda_backend() == "ffmpeg" dec = VideoDecoder(H265_VIDEO.path, device="cuda") - assert _core._get_backend_details(dec._decoder).startswith("FFmpeg CUDA") + _ = dec.get_frame_at(0) + assert "FFmpeg CUDA" in str(dec.cpu_fallback) # Check the setting "beta" effectively uses the BETA backend. # We also show that the affects decoder creation only. When the decoder @@ -1724,9 +1730,9 @@ def test_set_cuda_backend(self): with set_cuda_backend("beta"): dec = VideoDecoder(H265_VIDEO.path, device="cuda") assert _get_cuda_backend() == "ffmpeg" - assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA") + assert "Beta CUDA" in str(dec.cpu_fallback) with set_cuda_backend("ffmpeg"): - assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA") + assert "Beta CUDA" in str(dec.cpu_fallback) # Hacky way to ensure passing "cuda:1" is supported by both backends. We # just check that there's an error when passing cuda:N where N is too @@ -1737,6 +1743,69 @@ def test_set_cuda_backend(self): with set_cuda_backend(backend): VideoDecoder(H265_VIDEO.path, device=f"cuda:{bad_device_number}") + def test_cpu_fallback_no_fallback_on_cpu_device(self): + """Test that CPU device doesn't trigger fallback (it's not a fallback scenario).""" + decoder = VideoDecoder(NASA_VIDEO.path, device="cpu") + + _ = decoder[0] + + assert decoder.cpu_fallback.status_known + assert not bool(decoder.cpu_fallback) + assert "No fallback required" in str(decoder.cpu_fallback) + + @needs_cuda + def test_cpu_fallback_h265_video_ffmpeg_cuda(self): + """Test that H265 video triggers CPU fallback on FFmpeg CUDA interface.""" + # H265_VIDEO is known to trigger CPU fallback on FFmpeg CUDA + # because its dimensions are too small + decoder = VideoDecoder(H265_VIDEO.path, device="cuda") + + assert not decoder.cpu_fallback.status_known + + _ = decoder.get_frame_at(0) + + assert decoder.cpu_fallback.status_known + assert bool(decoder.cpu_fallback) + assert "Fallback status: Falling back due to:" in str(decoder.cpu_fallback) + + @needs_cuda + def test_cpu_fallback_no_fallback_on_supported_video(self): + """Test that supported videos don't trigger fallback on CUDA.""" + decoder = VideoDecoder(NASA_VIDEO.path, device="cuda") + + _ = decoder[0] + + assert not bool(decoder.cpu_fallback) + assert "No fallback required" in str(decoder.cpu_fallback) + + def test_cpu_fallback_status_cached(self): + """Test that cpu_fallback status is determined once and then cached.""" + decoder = VideoDecoder(NASA_VIDEO.path) + + _ = decoder[0] + first_status = str(decoder.cpu_fallback) + assert decoder.cpu_fallback.status_known + + _ = decoder[1] + second_status = str(decoder.cpu_fallback) + assert decoder.cpu_fallback.status_known + + assert first_status == second_status + + def test_cpu_fallback_multiple_access_methods(self): + """Test that cpu_fallback works with different frame access methods.""" + decoder = VideoDecoder(NASA_VIDEO.path) + + _ = decoder.get_frame_at(0) + assert decoder.cpu_fallback.status_known + status_after_get_frame = str(decoder.cpu_fallback) + + _ = decoder.get_frames_in_range(1, 3) + assert str(decoder.cpu_fallback) == status_after_get_frame + + _ = decoder.get_frame_played_at(0.5) + assert str(decoder.cpu_fallback) == status_after_get_frame + class TestAudioDecoder: @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))