Skip to content

Commit 7668eea

Browse files
committed
Remove resize from color conversion library tests
1 parent 2a391b6 commit 7668eea

File tree

1 file changed

+33
-46
lines changed

1 file changed

+33
-46
lines changed

test/test_transform_ops.py

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
create_from_file,
2121
get_frame_at_index,
2222
get_json_metadata,
23-
get_next_frame,
2423
)
2524

2625
from torchvision.transforms import v2
2726

2827
from .utils import (
2928
assert_frames_equal,
3029
assert_tensor_close_on_at_least,
30+
AV1_VIDEO,
31+
H265_VIDEO,
3132
NASA_VIDEO,
3233
needs_cuda,
3334
)
@@ -36,56 +37,46 @@
3637

3738

3839
class TestCoreVideoDecoderTransformOps:
39-
# We choose arbitrary values for width and height scaling to get better
40-
# test coverage. Some pairs upscale the image while others downscale it.
41-
@pytest.mark.parametrize(
42-
"width_scaling_factor,height_scaling_factor",
43-
((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
44-
)
45-
@pytest.mark.parametrize("input_video", [NASA_VIDEO])
46-
def test_color_conversion_library_with_scaling(
47-
self, input_video, width_scaling_factor, height_scaling_factor
48-
):
49-
decoder = create_from_file(str(input_video.path))
40+
@pytest.mark.parametrize("video", [NASA_VIDEO, H265_VIDEO, AV1_VIDEO])
41+
def test_color_conversion_library(self, video):
42+
decoder = create_from_file(str(video.path))
5043
add_video_stream(decoder)
5144
metadata = get_json_metadata(decoder)
5245
metadata_dict = json.loads(metadata)
53-
assert metadata_dict["width"] == input_video.width
54-
assert metadata_dict["height"] == input_video.height
55-
56-
target_height = int(input_video.height * height_scaling_factor)
57-
target_width = int(input_video.width * width_scaling_factor)
58-
if width_scaling_factor != 1.0:
59-
assert target_width != input_video.width
60-
if height_scaling_factor != 1.0:
61-
assert target_height != input_video.height
46+
num_frames = metadata_dict["numFramesFromHeader"]
6247

63-
filtergraph_decoder = create_from_file(str(input_video.path))
48+
filtergraph_decoder = create_from_file(str(video.path))
6449
_add_video_stream(
6550
filtergraph_decoder,
66-
transform_specs=f"resize, {target_height}, {target_width}",
6751
color_conversion_library="filtergraph",
6852
)
69-
filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
7053

71-
swscale_decoder = create_from_file(str(input_video.path))
54+
swscale_decoder = create_from_file(str(video.path))
7255
_add_video_stream(
7356
swscale_decoder,
74-
transform_specs=f"resize, {target_height}, {target_width}",
7557
color_conversion_library="swscale",
7658
)
77-
swscale_frame0, _, _ = get_next_frame(swscale_decoder)
78-
assert_frames_equal(filtergraph_frame0, swscale_frame0)
79-
assert filtergraph_frame0.shape == (3, target_height, target_width)
8059

81-
@pytest.mark.parametrize(
82-
"width_scaling_factor,height_scaling_factor",
83-
((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
84-
)
60+
for frame_index in [
61+
0,
62+
int(num_frames * 0.25),
63+
int(num_frames * 0.5),
64+
int(num_frames * 0.75),
65+
num_frames - 1,
66+
]:
67+
filtergraph_frame, *_ = get_frame_at_index(
68+
filtergraph_decoder, frame_index=frame_index
69+
)
70+
swscale_frame, *_ = get_frame_at_index(
71+
swscale_decoder, frame_index=frame_index
72+
)
73+
74+
assert_frames_equal(filtergraph_frame, swscale_frame)
75+
8576
@pytest.mark.parametrize("width", [30, 32, 300])
8677
@pytest.mark.parametrize("height", [128])
8778
def test_color_conversion_library_with_generated_videos(
88-
self, tmp_path, width, height, width_scaling_factor, height_scaling_factor
79+
self, tmp_path, width, height
8980
):
9081
# We consider filtergraph to be the reference color conversion library.
9182
# However the video decoder sometimes uses swscale as that is faster.
@@ -134,28 +125,24 @@ def test_color_conversion_library_with_generated_videos(
134125
assert metadata_dict["width"] == width
135126
assert metadata_dict["height"] == height
136127

137-
target_height = int(height * height_scaling_factor)
138-
target_width = int(width * width_scaling_factor)
139-
if width_scaling_factor != 1.0:
140-
assert target_width != width
141-
if height_scaling_factor != 1.0:
142-
assert target_height != height
128+
num_frames = metadata_dict["numFramesFromHeader"]
129+
assert num_frames is not None and num_frames == 1
143130

144131
filtergraph_decoder = create_from_file(str(video_path))
145132
_add_video_stream(
146133
filtergraph_decoder,
147-
transform_specs=f"resize, {target_height}, {target_width}",
148134
color_conversion_library="filtergraph",
149135
)
150-
filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
151136

152137
auto_decoder = create_from_file(str(video_path))
153-
add_video_stream(
138+
_add_video_stream(
154139
auto_decoder,
155-
transform_specs=f"resize, {target_height}, {target_width}",
140+
color_conversion_library="swscale",
156141
)
157-
auto_frame0, _, _ = get_next_frame(auto_decoder)
158-
assert_frames_equal(filtergraph_frame0, auto_frame0)
142+
143+
filtergraph_frame0, *_ = get_frame_at_index(filtergraph_decoder, frame_index=0)
144+
swscale_frame0, *_ = get_frame_at_index(auto_decoder, frame_index=0)
145+
assert_frames_equal(filtergraph_frame0, swscale_frame0)
159146

160147
@needs_cuda
161148
def test_scaling_on_cuda_fails(self):

0 commit comments

Comments
 (0)