Skip to content

Commit b32e6f3

Browse files
author
Molly Xu
committed
expose cpu_fallback
1 parent 7dbbf83 commit b32e6f3

File tree

3 files changed

+172
-1
lines changed

3 files changed

+172
-1
lines changed

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
from .._core import AudioStreamMetadata, VideoStreamMetadata
88
from ._audio_decoder import AudioDecoder # noqa
99
from ._decoder_utils import set_cuda_backend # noqa
10-
from ._video_decoder import VideoDecoder # noqa
10+
from ._video_decoder import FallbackInfo, VideoDecoder # noqa
1111

1212
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_video_decoder.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import io
88
import json
99
import numbers
10+
from dataclasses import dataclass
1011
from pathlib import Path
1112
from typing import List, Literal, Optional, Sequence, Tuple, Union
1213

@@ -22,6 +23,48 @@
2223
from torchcodec.transforms import DecoderTransform, Resize
2324

2425

26+
@dataclass
27+
class FallbackInfo:
28+
"""Information about decoder fallback status.
29+
30+
This class tracks whether hardware-accelerated decoding failed and the decoder
31+
fell back to software decoding.
32+
33+
Usage:
34+
- Use ``str(fallback_info)`` or ``print(fallback_info)`` to see the fallback status
35+
- Use ``bool(fallback_info)`` to check if any fallback occurred
36+
37+
Attributes:
38+
status_known (bool): Whether the fallback status has been determined.
39+
"""
40+
41+
def __init__(self):
42+
self.status_known = False
43+
self.__nvcuvid_unavailable = False
44+
self.__video_not_supported = False
45+
46+
def __bool__(self):
47+
"""Returns True if fallback occurred (and status is known)."""
48+
return self.status_known and (
49+
self.__nvcuvid_unavailable or self.__video_not_supported
50+
)
51+
52+
def __str__(self):
53+
"""Returns a human-readable string representation of the fallback status."""
54+
if not self.status_known:
55+
return "Fallback status: Unknown"
56+
57+
reasons = []
58+
if self.__nvcuvid_unavailable:
59+
reasons.append("NVcuvid unavailable")
60+
if self.__video_not_supported:
61+
reasons.append("Video not supported")
62+
63+
if reasons:
64+
return "Fallback status: Falling back due to: " + ", ".join(reasons)
65+
return "Fallback status: No fallback required"
66+
67+
2568
class VideoDecoder:
2669
"""A single-stream video decoder.
2770
@@ -180,13 +223,48 @@ def __init__(
180223
custom_frame_mappings=custom_frame_mappings_data,
181224
)
182225

226+
# Initialize fallback info
227+
self._fallback_info = FallbackInfo()
228+
183229
def __len__(self) -> int:
184230
return self._num_frames
185231

232+
@property
233+
def cpu_fallback(self) -> FallbackInfo:
234+
"""Get information about decoder fallback status.
235+
236+
Returns:
237+
FallbackInfo: Information about whether hardware-accelerated decoding
238+
failed and the decoder fell back to software decoding.
239+
240+
Note:
241+
The fallback status is only determined after the first frame access.
242+
Before that, the status will be "Unknown".
243+
"""
244+
return self._fallback_info
245+
246+
def _update_cpu_fallback(self):
247+
"""Update the fallback status if it hasn't been determined yet.
248+
249+
This method should be called after any frame decoding operation to determine
250+
if fallback to software decoding occurred.
251+
"""
252+
if not self._fallback_info.status_known:
253+
backend_details = core._get_backend_details(self._decoder)
254+
255+
self._fallback_info.status_known = True
256+
257+
if "CPU fallback" in backend_details:
258+
if "NVCUVID not available" in backend_details:
259+
self._fallback_info._FallbackInfo__nvcuvid_unavailable = True
260+
else:
261+
self._fallback_info._FallbackInfo__video_not_supported = True
262+
186263
def _getitem_int(self, key: int) -> Tensor:
187264
assert isinstance(key, int)
188265

189266
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
267+
self._update_cpu_fallback()
190268
return frame_data
191269

192270
def _getitem_slice(self, key: slice) -> Tensor:
@@ -199,6 +277,7 @@ def _getitem_slice(self, key: slice) -> Tensor:
199277
stop=stop,
200278
step=step,
201279
)
280+
self._update_cpu_fallback()
202281
return frame_data
203282

204283
def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
@@ -252,6 +331,7 @@ def get_frame_at(self, index: int) -> Frame:
252331
data, pts_seconds, duration_seconds = core.get_frame_at_index(
253332
self._decoder, frame_index=index
254333
)
334+
self._update_cpu_fallback()
255335
return Frame(
256336
data=data,
257337
pts_seconds=pts_seconds.item(),
@@ -271,6 +351,7 @@ def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
271351
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
272352
self._decoder, frame_indices=indices
273353
)
354+
self._update_cpu_fallback()
274355

275356
return FrameBatch(
276357
data=data,
@@ -300,6 +381,7 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
300381
stop=stop,
301382
step=step,
302383
)
384+
self._update_cpu_fallback()
303385
return FrameBatch(*frames)
304386

305387
def get_frame_played_at(self, seconds: float) -> Frame:
@@ -329,6 +411,7 @@ def get_frame_played_at(self, seconds: float) -> Frame:
329411
data, pts_seconds, duration_seconds = core.get_frame_at_pts(
330412
self._decoder, seconds
331413
)
414+
self._update_cpu_fallback()
332415
return Frame(
333416
data=data,
334417
pts_seconds=pts_seconds.item(),
@@ -350,6 +433,7 @@ def get_frames_played_at(
350433
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
351434
self._decoder, timestamps=seconds
352435
)
436+
self._update_cpu_fallback()
353437
return FrameBatch(
354438
data=data,
355439
pts_seconds=pts_seconds,
@@ -394,6 +478,7 @@ def get_frames_played_in_range(
394478
start_seconds=start_seconds,
395479
stop_seconds=stop_seconds,
396480
)
481+
self._update_cpu_fallback()
397482
return FrameBatch(*frames)
398483

399484

test/test_decoders.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,6 +1737,92 @@ def test_set_cuda_backend(self):
17371737
with set_cuda_backend(backend):
17381738
VideoDecoder(H265_VIDEO.path, device=f"cuda:{bad_device_number}")
17391739

1740+
def test_cpu_fallback_before_after_decoding(self):
1741+
decoder = VideoDecoder(NASA_VIDEO.path)
1742+
1743+
# Before accessing any frames, status should be unknown
1744+
assert not decoder.cpu_fallback.status_known
1745+
assert str(decoder.cpu_fallback) == "Fallback status: Unknown"
1746+
assert not bool(decoder.cpu_fallback)
1747+
1748+
# After accessing frames, status should be known
1749+
_ = decoder[0]
1750+
assert decoder.cpu_fallback.status_known
1751+
assert str(decoder.cpu_fallback) != "Fallback status: Unknown"
1752+
1753+
def test_cpu_fallback_no_fallback_on_cpu_device(self):
1754+
"""Test that CPU device doesn't trigger fallback (it's not a fallback scenario)."""
1755+
decoder = VideoDecoder(NASA_VIDEO.path, device="cpu")
1756+
1757+
_ = decoder[0]
1758+
1759+
assert decoder.cpu_fallback.status_known
1760+
assert not bool(decoder.cpu_fallback)
1761+
assert "No fallback required" in str(decoder.cpu_fallback)
1762+
1763+
@needs_cuda
1764+
def test_cpu_fallback_h265_video_ffmpeg_cuda(self):
1765+
"""Test that H265 video triggers CPU fallback on FFmpeg CUDA interface."""
1766+
# H265_VIDEO is known to trigger CPU fallback on FFmpeg CUDA
1767+
# because its dimensions are too small
1768+
decoder = VideoDecoder(H265_VIDEO.path, device="cuda")
1769+
1770+
_ = decoder.get_frame_at(0)
1771+
1772+
assert decoder.cpu_fallback.status_known
1773+
assert bool(decoder.cpu_fallback)
1774+
assert "Fallback status: Falling back due to:" in str(decoder.cpu_fallback)
1775+
1776+
@needs_cuda
1777+
def test_cpu_fallback_h265_video_beta_cuda(self):
1778+
"""Test that H265 video triggers CPU fallback on Beta CUDA interface."""
1779+
with set_cuda_backend("beta"):
1780+
decoder = VideoDecoder(H265_VIDEO.path, device="cuda")
1781+
1782+
_ = decoder.get_frame_at(0)
1783+
1784+
assert decoder.cpu_fallback.status_known
1785+
assert bool(decoder.cpu_fallback)
1786+
assert "Fallback status: Falling back due to:" in str(decoder.cpu_fallback)
1787+
1788+
@needs_cuda
1789+
def test_cpu_fallback_no_fallback_on_supported_video(self):
1790+
"""Test that supported videos don't trigger fallback on CUDA."""
1791+
decoder = VideoDecoder(NASA_VIDEO.path, device="cuda")
1792+
1793+
# Access a frame to determine status
1794+
_ = decoder[0]
1795+
1796+
assert not bool(decoder.cpu_fallback)
1797+
1798+
def test_cpu_fallback_status_cached(self):
1799+
"""Test that cpu_fallback status is determined once and then cached."""
1800+
decoder = VideoDecoder(NASA_VIDEO.path)
1801+
1802+
_ = decoder[0]
1803+
first_status = str(decoder.cpu_fallback)
1804+
assert decoder.cpu_fallback.status_known
1805+
1806+
_ = decoder[1]
1807+
second_status = str(decoder.cpu_fallback)
1808+
assert decoder.cpu_fallback.status_known
1809+
1810+
assert first_status == second_status
1811+
1812+
def test_cpu_fallback_multiple_access_methods(self):
1813+
"""Test that cpu_fallback works with different frame access methods."""
1814+
decoder = VideoDecoder(NASA_VIDEO.path)
1815+
1816+
_ = decoder.get_frame_at(0)
1817+
assert decoder.cpu_fallback.status_known
1818+
status_after_get_frame = str(decoder.cpu_fallback)
1819+
1820+
_ = decoder.get_frames_in_range(1, 3)
1821+
assert str(decoder.cpu_fallback) == status_after_get_frame
1822+
1823+
_ = decoder.get_frame_played_at(0.5)
1824+
assert str(decoder.cpu_fallback) == status_after_get_frame
1825+
17401826

17411827
class TestAudioDecoder:
17421828
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)