Skip to content

Commit 1670372

Browse files
authored
Properly launder EOF exception (#520)
1 parent afb63c0 commit 1670372

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
240240

241241
OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
242242
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
243-
auto result = videoDecoder->getFramePlayedAt(seconds);
243+
VideoDecoder::FrameOutput result;
244+
try {
245+
result = videoDecoder->getFramePlayedAt(seconds);
246+
} catch (const VideoDecoder::EndOfFileException& e) {
247+
C10_THROW_ERROR(IndexError, e.what());
248+
}
244249
return makeOpsFrameOutput(result);
245250
}
246251

test/decoders/test_video_decoder_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,17 @@ def test_get_frames_in_range(self, device):
308308
def test_throws_exception_at_eof(self, device):
309309
decoder = create_from_file(str(NASA_VIDEO.path))
310310
add_video_stream(decoder, device=device)
311+
311312
seek_to_pts(decoder, 12.979633)
312313
last_frame, _, _ = get_next_frame(decoder)
313314
reference_last_frame = NASA_VIDEO.get_frame_data_by_index(289)
314315
assert_frames_equal(last_frame, reference_last_frame.to(device))
315316
with pytest.raises(IndexError, match="no more frames"):
316317
get_next_frame(decoder)
317318

319+
with pytest.raises(IndexError, match="no more frames"):
320+
get_frame_at_pts(decoder, seconds=1000.0)
321+
318322
@pytest.mark.parametrize("device", cpu_and_cuda())
319323
def test_throws_exception_if_seek_too_far(self, device):
320324
decoder = create_from_file(str(NASA_VIDEO.path))

0 commit comments

Comments
 (0)