Skip to content

Commit 14045fa

Browse files
committed
Properly launder EOF exception
1 parent ba0063a commit 14045fa

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
240240

241241
OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
242242
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
243-
auto result = videoDecoder->getFramePlayedAt(seconds);
243+
VideoDecoder::FrameOutput result;
244+
try {
245+
result = videoDecoder->getFramePlayedAt(seconds);
246+
} catch (const VideoDecoder::EndOfFileException& e) {
247+
C10_THROW_ERROR(IndexError, e.what());
248+
}
244249
return makeOpsFrameOutput(result);
245250
}
246251

test/decoders/test_video_decoder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,14 @@ def test_get_frame_at(self, device, seek_mode):
350350
frame9 = decoder.get_frame_at(numpy.uint32(9))
351351
assert_frames_equal(ref_frame9, frame9.data)
352352

353+
@pytest.mark.parametrize("device", cpu_and_cuda())
354+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
355+
def test_get_frame_at_fails(self, device, seek_mode):
356+
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
357+
358+
with pytest.raises(IndexError, match="out of bounds"):
359+
decoder.get_frame_at(1000) # noqa
360+
353361
@pytest.mark.parametrize("device", cpu_and_cuda())
354362
def test_get_frame_at_tuple_unpacking(self, device):
355363
decoder = VideoDecoder(NASA_VIDEO.path, device=device)

test/decoders/test_video_decoder_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,18 @@ def test_get_frames_in_range(self, device):
308308
def test_throws_exception_at_eof(self, device):
309309
decoder = create_from_file(str(NASA_VIDEO.path))
310310
add_video_stream(decoder, device=device)
311+
312+
311313
seek_to_pts(decoder, 12.979633)
312314
last_frame, _, _ = get_next_frame(decoder)
313315
reference_last_frame = NASA_VIDEO.get_frame_data_by_index(289)
314316
assert_frames_equal(last_frame, reference_last_frame.to(device))
315317
with pytest.raises(IndexError, match="no more frames"):
316318
get_next_frame(decoder)
317319

320+
with pytest.raises(IndexError, match="no more frames"):
321+
get_frame_at_pts(decoder, seconds=1000.0)
322+
318323
@pytest.mark.parametrize("device", cpu_and_cuda())
319324
def test_throws_exception_if_seek_too_far(self, device):
320325
decoder = create_from_file(str(NASA_VIDEO.path))

0 commit comments

Comments
 (0)