77import io
88import json
99import numbers
10+ from dataclasses import dataclass
1011from pathlib import Path
1112from typing import List , Literal , Optional , Sequence , Tuple , Union
1213
2223from torchcodec .transforms import DecoderTransform , Resize
2324
2425
26+ @dataclass
27+ class FallbackInfo :
28+ """Information about decoder fallback status.
29+
30+ This class tracks whether hardware-accelerated decoding failed and the decoder
31+ fell back to software decoding.
32+
33+ Usage:
34+ - Use ``str(fallback_info)`` or ``print(fallback_info)`` to see the fallback status
35+ - Use ``bool(fallback_info)`` to check if any fallback occurred
36+
37+ Attributes:
38+ status_known (bool): Whether the fallback status has been determined.
39+ """
40+
41+ def __init__ (self ):
42+ self .status_known = False
43+ self .__nvcuvid_unavailable = False
44+ self .__video_not_supported = False
45+
46+ def __bool__ (self ):
47+ """Returns True if fallback occurred (and status is known)."""
48+ return self .status_known and (
49+ self .__nvcuvid_unavailable or self .__video_not_supported
50+ )
51+
52+ def __str__ (self ):
53+ """Returns a human-readable string representation of the fallback status."""
54+ if not self .status_known :
55+ return "Fallback status: Unknown"
56+
57+ reasons = []
58+ if self .__nvcuvid_unavailable :
59+ reasons .append ("NVcuvid unavailable" )
60+ if self .__video_not_supported :
61+ reasons .append ("Video not supported" )
62+
63+ if reasons :
64+ return "Fallback status: Falling back due to: " + ", " .join (reasons )
65+ return "Fallback status: No fallback required"
66+
67+
2568class VideoDecoder :
2669 """A single-stream video decoder.
2770
@@ -180,13 +223,48 @@ def __init__(
180223 custom_frame_mappings = custom_frame_mappings_data ,
181224 )
182225
226+ # Initialize fallback info
227+ self ._fallback_info = FallbackInfo ()
228+
183229 def __len__ (self ) -> int :
184230 return self ._num_frames
185231
232+ @property
233+ def cpu_fallback (self ) -> FallbackInfo :
234+ """Get information about decoder fallback status.
235+
236+ Returns:
237+ FallbackInfo: Information about whether hardware-accelerated decoding
238+ failed and the decoder fell back to software decoding.
239+
240+ Note:
241+ The fallback status is only determined after the first frame access.
242+ Before that, the status will be "Unknown".
243+ """
244+ return self ._fallback_info
245+
246+ def _update_cpu_fallback (self ):
247+ """Update the fallback status if it hasn't been determined yet.
248+
249+ This method should be called after any frame decoding operation to determine
250+ if fallback to software decoding occurred.
251+ """
252+ if not self ._fallback_info .status_known :
253+ backend_details = core ._get_backend_details (self ._decoder )
254+
255+ self ._fallback_info .status_known = True
256+
257+ if "CPU fallback" in backend_details :
258+ if "NVCUVID not available" in backend_details :
259+ self ._fallback_info ._FallbackInfo__nvcuvid_unavailable = True
260+ else :
261+ self ._fallback_info ._FallbackInfo__video_not_supported = True
262+
186263 def _getitem_int (self , key : int ) -> Tensor :
187264 assert isinstance (key , int )
188265
189266 frame_data , * _ = core .get_frame_at_index (self ._decoder , frame_index = key )
267+ self ._update_cpu_fallback ()
190268 return frame_data
191269
192270 def _getitem_slice (self , key : slice ) -> Tensor :
@@ -199,6 +277,7 @@ def _getitem_slice(self, key: slice) -> Tensor:
199277 stop = stop ,
200278 step = step ,
201279 )
280+ self ._update_cpu_fallback ()
202281 return frame_data
203282
204283 def __getitem__ (self , key : Union [numbers .Integral , slice ]) -> Tensor :
@@ -252,6 +331,7 @@ def get_frame_at(self, index: int) -> Frame:
252331 data , pts_seconds , duration_seconds = core .get_frame_at_index (
253332 self ._decoder , frame_index = index
254333 )
334+ self ._update_cpu_fallback ()
255335 return Frame (
256336 data = data ,
257337 pts_seconds = pts_seconds .item (),
@@ -271,6 +351,7 @@ def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
271351 data , pts_seconds , duration_seconds = core .get_frames_at_indices (
272352 self ._decoder , frame_indices = indices
273353 )
354+ self ._update_cpu_fallback ()
274355
275356 return FrameBatch (
276357 data = data ,
@@ -300,6 +381,7 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
300381 stop = stop ,
301382 step = step ,
302383 )
384+ self ._update_cpu_fallback ()
303385 return FrameBatch (* frames )
304386
305387 def get_frame_played_at (self , seconds : float ) -> Frame :
@@ -329,6 +411,7 @@ def get_frame_played_at(self, seconds: float) -> Frame:
329411 data , pts_seconds , duration_seconds = core .get_frame_at_pts (
330412 self ._decoder , seconds
331413 )
414+ self ._update_cpu_fallback ()
332415 return Frame (
333416 data = data ,
334417 pts_seconds = pts_seconds .item (),
@@ -350,6 +433,7 @@ def get_frames_played_at(
350433 data , pts_seconds , duration_seconds = core .get_frames_by_pts (
351434 self ._decoder , timestamps = seconds
352435 )
436+ self ._update_cpu_fallback ()
353437 return FrameBatch (
354438 data = data ,
355439 pts_seconds = pts_seconds ,
@@ -394,6 +478,7 @@ def get_frames_played_in_range(
394478 start_seconds = start_seconds ,
395479 stop_seconds = stop_seconds ,
396480 )
481+ self ._update_cpu_fallback ()
397482 return FrameBatch (* frames )
398483
399484
0 commit comments