
Commit 4245bdd

More involved testing
1 parent a95dd4c commit 4245bdd

3 files changed: +72 −28 lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 19 additions & 6 deletions
@@ -69,10 +69,18 @@ void CpuDeviceInterface::initializeVideo(
     first = false;
   }
   if (!transforms.empty()) {
-    // Note that we ensure that the transforms come AFTER the format conversion.
-    // This means that the transforms are applied in the output pixel format and
-    // colorspace.
-    filters_ += "," + filters.str();
+    // Note [Transform and Format Conversion Order]
+    // We have to ensure that all user filters happen AFTER the explicit format
+    // conversion. That is, we want the filters to be applied in RGB24, not the
+    // pixel format of the input frame.
+    //
+    // The output frame will always be in RGB24, as we specify the sink node with
+    // AV_PIX_FMT_RGB24. Filtergraph will automatically insert a format
+    // conversion to ensure the output frame matches the pixel format
+    // specified in the sink. But by default, it will insert it after the user
+    // filters. We need an explicit format conversion to get the behavior we
+    // want.
+    filters_ = "format=rgb24," + filters.str();
   }
 
   initialized_ = true;
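
The ordering problem this note describes can be illustrated with plain strings. Below is a minimal sketch, not the library's code: the `crop` filter is a hypothetical stand-in for any user transform, and the string concatenation mirrors what the C++ member filters_ is assumed to hold.

    # Hypothetical user transform, standing in for any filter chain.
    user_filters = "crop=w=100:h=100"

    # Old behavior: filtergraph's automatic conversion lands AFTER the user
    # filters, so the crop runs in the input pixel format (e.g. yuv420p):
    #     crop=w=100:h=100 -> [auto-inserted format conversion to rgb24]
    #
    # New behavior: prepend an explicit conversion so the crop runs on RGB24.
    filters = "format=rgb24," + user_filters
    assert filters == "format=rgb24,crop=w=100:h=100"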
@@ -233,9 +241,14 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
     swsContext_ = createSwsContext(
         swsFrameContext,
         avFrame->colorspace,
+
+        // See Note [Transform and Format Conversion Order] for more on the
+        // output pixel format.
         /*outputFormat=*/AV_PIX_FMT_RGB24,
-        /*swsFlags=*/0); // We don't set any flags because we don't yet use
-                         // sws_scale() for resizing.
+
+        // We don't set any flags because we don't yet use sws_scale() for
+        // resizing.
+        /*swsFlags=*/0);
     prevSwsFrameContext_ = swsFrameContext;
   }
 

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 6 additions & 12 deletions
@@ -93,19 +93,13 @@ class CpuDeviceInterface : public DeviceInterface {
   // initialization, we convert the user-supplied transforms into this string of
   // filters.
   //
-  // Note that we start with just the format conversion, and then we ensure that
-  // the user-supplied filters always happen AFTER the format conversion. We
-  // want the user-supplied filters to operate on frames in the output pixel
-  // format and colorspace.
+  // Note that if there are no user-supplied transforms, then the default filter
+  // we use is the copy filter, which is just an identity: it emits the output
+  // frame unchanged. We supply such a filter because we can't supply just the
+  // empty string; we must supply SOME filter.
   //
-  // We apply the transforms on the output pixel format and colorspace because
-  // then decoder-native transforms are as close as possible to returning
-  // untransformed frames and applying TorchVision transforms to them.
-  //
-  // We ensure that the transforms happen on the output pixel format and
-  // colorspace by making sure all of the user-supplied filters happen AFTER
-  // an explicit format conversion.
-  std::string filters_ = "format=rgb24";
+  // See also Note [Transform and Format Conversion Order] for more on filters.
+  std::string filters_ = "copy";
 
   // Values set during initialization and referred to in
   // getColorConversionLibrary().
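
Taken together with the format-conversion note, the selection logic amounts to the following. This is a hypothetical Python mirror of the C++ initialization, with build_filters invented purely for illustration:

    def build_filters(user_filters: str) -> str:
        # An empty filtergraph description is invalid, so fall back to the
        # identity filter "copy" when there are no user transforms.
        if not user_filters:
            return "copy"
        # See Note [Transform and Format Conversion Order].
        return "format=rgb24," + user_filters

    assert build_filters("") == "copy"
    assert build_filters("scale=320:240") == "format=rgb24,scale=320:240"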

test/test_transform_ops.py

Lines changed: 47 additions & 10 deletions
@@ -31,19 +31,25 @@
     H265_VIDEO,
     NASA_VIDEO,
     needs_cuda,
+    TEST_SRC_2_720P,
 )
 
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
 
 class TestCoreVideoDecoderTransformOps:
-    @pytest.mark.parametrize("video", [NASA_VIDEO, H265_VIDEO, AV1_VIDEO])
-    def test_color_conversion_library(self, video):
+    def get_num_frames_core_ops(self, video):
         decoder = create_from_file(str(video.path))
         add_video_stream(decoder)
         metadata = get_json_metadata(decoder)
         metadata_dict = json.loads(metadata)
         num_frames = metadata_dict["numFramesFromHeader"]
+        assert num_frames is not None
+        return num_frames
+
+    @pytest.mark.parametrize("video", [NASA_VIDEO, H265_VIDEO, AV1_VIDEO])
+    def test_color_conversion_library(self, video):
+        num_frames = self.get_num_frames_core_ops(video)
 
         filtergraph_decoder = create_from_file(str(video.path))
         _add_video_stream(
@@ -170,32 +176,63 @@ def test_transform_fails(self):
         "height_scaling_factor, width_scaling_factor",
         ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
     )
-    def test_resize_torchvision(self, height_scaling_factor, width_scaling_factor):
-        height = int(NASA_VIDEO.get_height() * height_scaling_factor)
-        width = int(NASA_VIDEO.get_width() * width_scaling_factor)
+    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
+    def test_resize_torchvision(
+        self, video, height_scaling_factor, width_scaling_factor
+    ):
+        num_frames = self.get_num_frames_core_ops(video)
+
+        height = int(video.get_height() * height_scaling_factor)
+        width = int(video.get_width() * width_scaling_factor)
         resize_spec = f"resize, {height}, {width}"
 
-        decoder_resize = create_from_file(str(NASA_VIDEO.path))
+        decoder_resize = create_from_file(str(video.path))
         add_video_stream(decoder_resize, transform_specs=resize_spec)
 
-        decoder_full = create_from_file(str(NASA_VIDEO.path))
+        decoder_full = create_from_file(str(video.path))
         add_video_stream(decoder_full)
 
-        for frame_index in [0, 10, 17, 100, 230, 389]:
-            expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
+        for frame_index in [
+            0,
+            int(num_frames * 0.1),
+            int(num_frames * 0.2),
+            int(num_frames * 0.3),
+            int(num_frames * 0.4),
+            int(num_frames * 0.5),
+            int(num_frames * 0.75),
+            int(num_frames * 0.90),
+            num_frames - 1,
+        ]:
+            expected_shape = (video.get_num_color_channels(), height, width)
             frame_resize, *_ = get_frame_at_index(
                 decoder_resize, frame_index=frame_index
             )
 
             frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index)
             frame_tv = v2.functional.resize(frame_full, size=(height, width))
+            frame_tv_no_antialias = v2.functional.resize(
+                frame_full, size=(height, width), antialias=False
+            )
 
             assert frame_resize.shape == expected_shape
             assert frame_tv.shape == expected_shape
+            assert frame_tv_no_antialias.shape == expected_shape
 
             assert_tensor_close_on_at_least(
-                frame_resize, frame_tv, percentage=99, atol=1
+                frame_resize, frame_tv, percentage=99.9, atol=1
             )
+            torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)
+
+            if height_scaling_factor < 1 or width_scaling_factor < 1:
+                # Antialiasing is only relevant when down-scaling!
+                with pytest.raises(AssertionError, match="Expected at least"):
+                    assert_tensor_close_on_at_least(
+                        frame_resize, frame_tv_no_antialias, percentage=99, atol=1
+                    )
+                with pytest.raises(AssertionError, match="Tensor-likes are not close"):
+                    torch.testing.assert_close(
+                        frame_resize, frame_tv_no_antialias, rtol=0, atol=6
+                    )
 
     def test_resize_ffmpeg(self):
         height = 135
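
The match="Expected at least" assertions above depend on the error message produced by assert_tensor_close_on_at_least, a helper from the test utilities. A minimal sketch of what that helper is assumed to check: at least `percentage` percent of the elements must be within `atol` of the reference.

    import torch

    def assert_tensor_close_on_at_least(actual, expected, *, percentage, atol):
        # Fraction of elements whose absolute difference is within atol.
        close = (actual.float() - expected.float()).abs() <= atol
        pct_close = close.float().mean().item() * 100
        assert pct_close >= percentage, (
            f"Expected at least {percentage}% of elements to be within "
            f"atol={atol}, but only {pct_close:.2f}% were"
        )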
