Skip to content

Commit 9387537

Browse files
committed
Test, and fix
1 parent 5113b9c commit 9387537

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1073,8 +1073,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
1073 1073     BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
1074 1074
1075 1075     for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
1076      -     DecodedOutput singleOut =
1077      -         getFrameAtIndex(streamIndex, i, output.frames[f]);
     1076 +     DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
     1077 +     if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
     1078 +       output.frames[f] = singleOut.frame;
     1079 +     }
1078 1080       output.ptsSeconds[f] = singleOut.ptsSeconds;
1079 1081       output.durationSeconds[f] = singleOut.durationSeconds;
1080 1082     }
@@ -1166,8 +1168,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
1166 1168     int64_t numFrames = stopFrameIndex - startFrameIndex;
1167 1169     BatchDecodedOutput output(numFrames, options, streamMetadata);
1168 1170     for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
1169      -     DecodedOutput singleOut =
1170      -         getFrameAtIndex(streamIndex, i, output.frames[f]);
     1171 +     DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
     1172 +     if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
     1173 +       output.frames[f] = singleOut.frame;
     1174 +     }
1171 1175       output.ptsSeconds[f] = singleOut.ptsSeconds;
1172 1176       output.durationSeconds[f] = singleOut.durationSeconds;
1173 1177     }

test/decoders/test_video_decoder_ops.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
  27   27       get_frame_at_index,
  28   28       get_frame_at_pts,
  29   29       get_frames_at_indices,
       30 +     get_frames_by_pts_in_range,
  30   31       get_frames_in_range,
  31   32       get_json_metadata,
  32   33       get_next_frame,
@@ -383,6 +384,48 @@ def test_color_conversion_library_with_scaling(
 383  384           swscale_frame0, _, _ = get_next_frame(swscale_decoder)
 384  385           assert_tensor_equal(filtergraph_frame0, swscale_frame0)
 385  386
      387 +     @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
      388 +     @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
      389 +     def test_color_conversion_library_with_dimension_order(
      390 +         self, dimension_order, color_conversion_library
      391 +     ):
      392 +         decoder = create_from_file(str(NASA_VIDEO.path))
      393 +         _add_video_stream(
      394 +             decoder,
      395 +             color_conversion_library=color_conversion_library,
      396 +             dimension_order=dimension_order,
      397 +         )
      398 +         scan_all_streams_to_update_metadata(decoder)
      399 +
      400 +         frame0_ref = NASA_VIDEO.get_frame_data_by_index(0)
      401 +         C, H, W = frame0_ref.shape
      402 +         if dimension_order == "NHWC":
      403 +             frame0_ref = frame0_ref.permute(1, 2, 0)
      404 +         expected_shape = frame0_ref.shape
      405 +
      406 +         stream_index = 3
      407 +         frame0, *_ = get_frame_at_index(
      408 +             decoder, stream_index=stream_index, frame_index=0
      409 +         )
      410 +         assert frame0.shape == expected_shape
      411 +         assert_tensor_equal(frame0, frame0_ref)
      412 +
      413 +         frame0, *_ = get_frame_at_pts(decoder, seconds=0.0)
      414 +         assert frame0.shape == expected_shape
      415 +         assert_tensor_equal(frame0, frame0_ref)
      416 +
      417 +         frames, *_ = get_frames_in_range(
      418 +             decoder, stream_index=stream_index, start=0, stop=3
      419 +         )
      420 +         assert frames.shape[1:] == expected_shape
      421 +         assert_tensor_equal(frames[0], frame0_ref)
      422 +
      423 +         frames, *_ = get_frames_by_pts_in_range(
      424 +             decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1
      425 +         )
      426 +         assert frames.shape[1:] == expected_shape
      427 +         assert_tensor_equal(frames[0], frame0_ref)
      428 +
 386  429       @pytest.mark.parametrize(
 387  430           "width_scaling_factor,height_scaling_factor",
 388  431           ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),

0 commit comments

Comments
 (0)