Skip to content

Commit 6e69c8c

Browse files
author
Molly Xu
committed
modify comments
1 parent cf5b718 commit 6e69c8c

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -223,23 +223,26 @@ def __init__(
223223
)
224224

225225
self._fallback_info = FallbackInfo()
226+
self._has_decoded_frame = False
226227

227228
def __len__(self) -> int:
228229
return self._num_frames
229230

230231
@property
231232
def cpu_fallback(self) -> FallbackInfo:
232-
# We can only determine whether fallback to CPU is happening after
233-
# the first frame access. Before that, the status will be "Unknown".
233+
# We can only determine whether fallback to CPU is happening when this
234+
# property is accessed and requires that at least one frame has been decoded.
235+
self._update_cpu_fallback()
234236
return self._fallback_info
235237

236238
def _update_cpu_fallback(self):
237239
"""Update the fallback status if it hasn't been determined yet.
238240
239-
This method should be called after any frame decoding operation to determine
240-
if fallback to software decoding occurred.
241+
This method queries the C++ backend to determine if fallback to CPU
242+
decoding occurred. The query is only performed after at least one frame
243+
has been decoded.
241244
"""
242-
if not self._fallback_info.status_known:
245+
if not self._fallback_info.status_known and self._has_decoded_frame:
243246
backend_details = core._get_backend_details(self._decoder)
244247

245248
self._fallback_info.status_known = True
@@ -254,7 +257,7 @@ def _getitem_int(self, key: int) -> Tensor:
254257
assert isinstance(key, int)
255258

256259
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
257-
self._update_cpu_fallback()
260+
self._has_decoded_frame = True
258261
return frame_data
259262

260263
def _getitem_slice(self, key: slice) -> Tensor:
@@ -267,7 +270,7 @@ def _getitem_slice(self, key: slice) -> Tensor:
267270
stop=stop,
268271
step=step,
269272
)
270-
self._update_cpu_fallback()
273+
self._has_decoded_frame = True
271274
return frame_data
272275

273276
def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
@@ -321,7 +324,7 @@ def get_frame_at(self, index: int) -> Frame:
321324
data, pts_seconds, duration_seconds = core.get_frame_at_index(
322325
self._decoder, frame_index=index
323326
)
324-
self._update_cpu_fallback()
327+
self._has_decoded_frame = True
325328
return Frame(
326329
data=data,
327330
pts_seconds=pts_seconds.item(),
@@ -341,7 +344,7 @@ def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
341344
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
342345
self._decoder, frame_indices=indices
343346
)
344-
self._update_cpu_fallback()
347+
self._has_decoded_frame = True
345348

346349
return FrameBatch(
347350
data=data,
@@ -371,7 +374,7 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
371374
stop=stop,
372375
step=step,
373376
)
374-
self._update_cpu_fallback()
377+
self._has_decoded_frame = True
375378
return FrameBatch(*frames)
376379

377380
def get_frame_played_at(self, seconds: float) -> Frame:
@@ -401,7 +404,7 @@ def get_frame_played_at(self, seconds: float) -> Frame:
401404
data, pts_seconds, duration_seconds = core.get_frame_at_pts(
402405
self._decoder, seconds
403406
)
404-
self._update_cpu_fallback()
407+
self._has_decoded_frame = True
405408
return Frame(
406409
data=data,
407410
pts_seconds=pts_seconds.item(),
@@ -423,7 +426,7 @@ def get_frames_played_at(
423426
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
424427
self._decoder, timestamps=seconds
425428
)
426-
self._update_cpu_fallback()
429+
self._has_decoded_frame = True
427430
return FrameBatch(
428431
data=data,
429432
pts_seconds=pts_seconds,
@@ -468,7 +471,7 @@ def get_frames_played_in_range(
468471
start_seconds=start_seconds,
469472
stop_seconds=stop_seconds,
470473
)
471-
self._update_cpu_fallback()
474+
self._has_decoded_frame = True
472475
return FrameBatch(*frames)
473476

474477

0 commit comments

Comments
 (0)