Skip to content

Commit 5ac8321

Browse files
author
Molly Xu
committed
address feedback:
1 parent 6e69c8c commit 5ac8321

File tree

5 files changed

+40
-51
lines changed

5 files changed

+40
-51
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
241241
std::optional<torch::Tensor> preAllocatedOutputTensor) {
242242
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
243243

244+
hasDecodedFrame_ = true;
245+
244246
// All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by
245247
// converting them to NV12.
246248
avFrame = maybeConvertAVFrameToNV12OrRGB24(avFrame);
@@ -358,6 +360,10 @@ std::string CudaDeviceInterface::getDetails() {
358360
// Note: for this interface specifically the fallback is only known after a
359361
// frame has been decoded, not before: that's when FFmpeg decides to fallback,
360362
// so we can't know earlier.
363+
if (!hasDecodedFrame_) {
364+
return std::string(
365+
"FFmpeg CUDA Device Interface. Fallback status unknown (no frames decoded).");
366+
}
361367
return std::string("FFmpeg CUDA Device Interface. Using ") +
362368
(usingCPUFallback_ ? "CPU fallback." : "NVDEC.");
363369
}

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class CudaDeviceInterface : public DeviceInterface {
6363
std::unique_ptr<FilterGraph> nv12Conversion_;
6464

6565
bool usingCPUFallback_ = false;
66+
bool hasDecodedFrame_ = false;
6667
};
6768

6869
} // namespace facebook::torchcodec

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
from .._core import AudioStreamMetadata, VideoStreamMetadata
88
from ._audio_decoder import AudioDecoder # noqa
99
from ._decoder_utils import set_cuda_backend # noqa
10-
from ._video_decoder import FallbackInfo, VideoDecoder # noqa
10+
from ._video_decoder import CpuFallbackStatus, VideoDecoder # noqa
1111

1212
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_video_decoder.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,28 @@
2424

2525

2626
@dataclass
27-
class FallbackInfo:
28-
"""Information about decoder fallback status.
27+
class CpuFallbackStatus:
28+
"""Information about CPU fallback status.
2929
3030
This class tracks whether the decoder fell back to CPU decoding.
3131
3232
Usage:
33-
- Use ``str(fallback_info)`` or ``print(fallback_info)`` to see the cpu fallback status
34-
- Use ``bool(fallback_info)`` to check if any fallback occurred
33+
- Use ``str(cpu_fallback_status)`` or ``print(cpu_fallback_status)`` to see the cpu fallback status
34+
- Use ``bool(cpu_fallback_status)`` to check if any fallback occurred
3535
3636
Attributes:
3737
status_known (bool): Whether the fallback status has been determined.
3838
"""
3939

4040
def __init__(self):
4141
self.status_known = False
42-
self.__nvcuvid_unavailable = False
43-
self.__video_not_supported = False
42+
self._nvcuvid_unavailable = False
43+
self._video_not_supported = False
4444

4545
def __bool__(self):
4646
"""Returns True if fallback occurred."""
4747
return self.status_known and (
48-
self.__nvcuvid_unavailable or self.__video_not_supported
48+
self._nvcuvid_unavailable or self._video_not_supported
4949
)
5050

5151
def __str__(self):
@@ -54,9 +54,9 @@ def __str__(self):
5454
return "Fallback status: Unknown"
5555

5656
reasons = []
57-
if self.__nvcuvid_unavailable:
57+
if self._nvcuvid_unavailable:
5858
reasons.append("NVcuvid unavailable")
59-
if self.__video_not_supported:
59+
if self._video_not_supported:
6060
reasons.append("Video not supported")
6161

6262
if reasons:
@@ -142,6 +142,10 @@ class VideoDecoder:
142142
stream_index (int): The stream index that this decoder is retrieving frames from. If a
143143
stream index was provided at initialization, this is the same value. If it was left
144144
unspecified, this is the :term:`best stream`.
145+
cpu_fallback (CpuFallbackStatus): Information about whether the decoder fell back to CPU
146+
decoding. Use ``bool(cpu_fallback)`` to check if fallback occurred, or
147+
``str(cpu_fallback)`` to get a human-readable status message. The status is only
148+
determined after at least one frame has been decoded.
145149
"""
146150

147151
def __init__(
@@ -222,42 +226,33 @@ def __init__(
222226
custom_frame_mappings=custom_frame_mappings_data,
223227
)
224228

225-
self._fallback_info = FallbackInfo()
226-
self._has_decoded_frame = False
229+
self._cpu_fallback = CpuFallbackStatus()
227230

228231
def __len__(self) -> int:
229232
return self._num_frames
230233

231234
@property
232-
def cpu_fallback(self) -> FallbackInfo:
235+
def cpu_fallback(self) -> CpuFallbackStatus:
233236
# We can only determine whether fallback to CPU is happening when this
234237
# property is accessed and requires that at least one frame has been decoded.
235-
self._update_cpu_fallback()
236-
return self._fallback_info
237-
238-
def _update_cpu_fallback(self):
239-
"""Update the fallback status if it hasn't been determined yet.
240-
241-
This method queries the C++ backend to determine if fallback to CPU
242-
decoding occurred. The query is only performed after at least one frame
243-
has been decoded.
244-
"""
245-
if not self._fallback_info.status_known and self._has_decoded_frame:
238+
if not self._cpu_fallback.status_known:
246239
backend_details = core._get_backend_details(self._decoder)
247240

248-
self._fallback_info.status_known = True
241+
if "status unknown" not in backend_details:
242+
self._cpu_fallback.status_known = True
243+
244+
if "CPU fallback" in backend_details:
245+
if "NVCUVID not available" in backend_details:
246+
self._cpu_fallback._nvcuvid_unavailable = True
247+
else:
248+
self._cpu_fallback._video_not_supported = True
249249

250-
if "CPU fallback" in backend_details:
251-
if "NVCUVID not available" in backend_details:
252-
self._fallback_info._FallbackInfo__nvcuvid_unavailable = True
253-
else:
254-
self._fallback_info._FallbackInfo__video_not_supported = True
250+
return self._cpu_fallback
255251

256252
def _getitem_int(self, key: int) -> Tensor:
257253
assert isinstance(key, int)
258254

259255
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
260-
self._has_decoded_frame = True
261256
return frame_data
262257

263258
def _getitem_slice(self, key: slice) -> Tensor:
@@ -270,7 +265,6 @@ def _getitem_slice(self, key: slice) -> Tensor:
270265
stop=stop,
271266
step=step,
272267
)
273-
self._has_decoded_frame = True
274268
return frame_data
275269

276270
def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
@@ -324,7 +318,6 @@ def get_frame_at(self, index: int) -> Frame:
324318
data, pts_seconds, duration_seconds = core.get_frame_at_index(
325319
self._decoder, frame_index=index
326320
)
327-
self._has_decoded_frame = True
328321
return Frame(
329322
data=data,
330323
pts_seconds=pts_seconds.item(),
@@ -344,7 +337,6 @@ def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
344337
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
345338
self._decoder, frame_indices=indices
346339
)
347-
self._has_decoded_frame = True
348340

349341
return FrameBatch(
350342
data=data,
@@ -374,7 +366,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
374366
stop=stop,
375367
step=step,
376368
)
377-
self._has_decoded_frame = True
378369
return FrameBatch(*frames)
379370

380371
def get_frame_played_at(self, seconds: float) -> Frame:
@@ -404,7 +395,6 @@ def get_frame_played_at(self, seconds: float) -> Frame:
404395
data, pts_seconds, duration_seconds = core.get_frame_at_pts(
405396
self._decoder, seconds
406397
)
407-
self._has_decoded_frame = True
408398
return Frame(
409399
data=data,
410400
pts_seconds=pts_seconds.item(),
@@ -426,7 +416,6 @@ def get_frames_played_at(
426416
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
427417
self._decoder, timestamps=seconds
428418
)
429-
self._has_decoded_frame = True
430419
return FrameBatch(
431420
data=data,
432421
pts_seconds=pts_seconds,
@@ -471,7 +460,6 @@ def get_frames_played_in_range(
471460
start_seconds=start_seconds,
472461
stop_seconds=stop_seconds,
473462
)
474-
self._has_decoded_frame = True
475463
return FrameBatch(*frames)
476464

477465

test/test_decoders.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,19 +1737,6 @@ def test_set_cuda_backend(self):
17371737
with set_cuda_backend(backend):
17381738
VideoDecoder(H265_VIDEO.path, device=f"cuda:{bad_device_number}")
17391739

1740-
def test_cpu_fallback_before_after_decoding(self):
1741-
decoder = VideoDecoder(NASA_VIDEO.path)
1742-
1743-
# Before accessing any frames, status should be unknown
1744-
assert not decoder.cpu_fallback.status_known
1745-
assert str(decoder.cpu_fallback) == "Fallback status: Unknown"
1746-
assert not bool(decoder.cpu_fallback)
1747-
1748-
# After accessing frames, status should be known
1749-
_ = decoder[0]
1750-
assert decoder.cpu_fallback.status_known
1751-
assert str(decoder.cpu_fallback) != "Fallback status: Unknown"
1752-
17531740
def test_cpu_fallback_no_fallback_on_cpu_device(self):
17541741
"""Test that CPU device doesn't trigger fallback (it's not a fallback scenario)."""
17551742
decoder = VideoDecoder(NASA_VIDEO.path, device="cpu")
@@ -1767,6 +1754,8 @@ def test_cpu_fallback_h265_video_ffmpeg_cuda(self):
17671754
# because its dimensions are too small
17681755
decoder = VideoDecoder(H265_VIDEO.path, device="cuda")
17691756

1757+
assert not decoder.cpu_fallback.status_known
1758+
17701759
_ = decoder.get_frame_at(0)
17711760

17721761
assert decoder.cpu_fallback.status_known
@@ -1779,9 +1768,14 @@ def test_cpu_fallback_h265_video_beta_cuda(self):
17791768
with set_cuda_backend("beta"):
17801769
decoder = VideoDecoder(H265_VIDEO.path, device="cuda")
17811770

1771+
# Before accessing any frames, status should be unknown
1772+
assert decoder.cpu_fallback.status_known
1773+
17821774
_ = decoder.get_frame_at(0)
17831775

1776+
# After accessing frames, status should be known
17841777
assert decoder.cpu_fallback.status_known
1778+
17851779
assert bool(decoder.cpu_fallback)
17861780
assert "Fallback status: Falling back due to:" in str(decoder.cpu_fallback)
17871781

0 commit comments

Comments
 (0)