Commit 98cf81b

Implement decoder native transforms API
1 parent 5344ab4 commit 98cf81b
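
In short: VideoDecoder gains a public transforms argument that accepts TorchCodec DecoderNativeTransform objects and, when TorchVision is installed, TorchVision v2 transforms as well. A minimal usage sketch inferred from the diff below ("video.mp4" is a placeholder path, not from this commit):

    from torchcodec.decoders import VideoDecoder
    from torchcodec.transforms import Resize
    from torchvision.transforms import v2

    # Resize inside the decoder using TorchCodec's own transform type.
    decoder = VideoDecoder("video.mp4", transforms=[Resize(size=(270, 480))])

    # Equivalent: a TorchVision v2 Resize is converted internally.
    decoder = VideoDecoder("video.mp4", transforms=[v2.Resize(size=(270, 480))])

    frame = decoder[0]  # frames come back already resized to (270, 480)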

File tree

2 files changed (+109, -74 lines)

src/torchcodec/decoders/_video_decoder.py

Lines changed: 49 additions & 12 deletions
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import device as torch_device, Tensor
@@ -19,6 +19,7 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
+from torchcodec.transforms import DecoderNativeTransform, Resize
 
 
 class VideoDecoder:
@@ -103,7 +104,7 @@ def __init__(
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
         device: Optional[Union[str, torch_device]] = "cpu",
-        transforms: List[Any] = [],  # TRANSFORMS TODO: what is the user-facing type?
+        transforms: Optional[Sequence[DecoderNativeTransform]] = None,
         seek_mode: Literal["exact", "approximate"] = "exact",
         custom_frame_mappings: Optional[
             Union[str, bytes, io.RawIOBase, io.BufferedReader]
@@ -149,7 +150,7 @@ def __init__(
 
         device_variant = _get_cuda_backend()
 
-        transform_specs = make_transform_specs(transforms)
+        transform_specs = _make_transform_specs(transforms)
 
         core.add_video_stream(
             self._decoder,
@@ -436,20 +437,56 @@ def _get_and_validate_stream_metadata(
     )
 
 
-def make_transform_specs(transforms: List[Any]) -> str:
-    from torchvision.transforms import v2
+# This function, _make_transform_specs, and the transforms argument to
+# VideoDecoder actually accept a union of DecoderNativeTransform and
+# TorchVision transforms. We don't put that union in the type annotation
+# because it would require importing torchvision at module scope, which would
+# make torchvision a hard dependency.
+# TODO: better explanation of the above.
+def _convert_to_decoder_native_transforms(
+    transforms: Sequence[DecoderNativeTransform],
+) -> Sequence[DecoderNativeTransform]:
+    try:
+        from torchvision.transforms import v2
+
+        tv_available = True
+    except ImportError:
+        tv_available = False
 
-    transform_specs = []
+    converted_transforms = []
     for transform in transforms:
-        if isinstance(transform, v2.Resize):
-            if len(transform.size) != 2:
-                raise ValueError(
-                    f"Resize transform must have a (height, width) pair for the size, got {transform.size}."
-                )
-            transform_specs.append(f"resize, {transform.size[0]}, {transform.size[1]}")
-        else:
-            raise ValueError(f"Unsupported transform {transform}.")
-    return ";".join(transform_specs)
+        if not isinstance(transform, DecoderNativeTransform):
+            if not tv_available:
+                raise ValueError(
+                    f"The supplied transform, {transform}, is not a TorchCodec "
+                    "DecoderNativeTransform. TorchCodec also accepts TorchVision "
+                    "v2 transforms, but TorchVision is not installed."
+                )
+            if isinstance(transform, v2.Resize):
+                if len(transform.size) != 2:
+                    raise ValueError(
+                        "TorchVision Resize transform must have a (height, width) "
+                        f"pair for the size, got {transform.size}."
+                    )
+                converted_transforms.append(Resize(size=transform.size))
+            else:
+                raise ValueError(
+                    f"Unsupported transform: {transform}. Transforms must be "
+                    "either a TorchCodec DecoderNativeTransform or a TorchVision "
+                    "v2 transform."
+                )
+        else:
+            converted_transforms.append(transform)
+
+    return converted_transforms
+
+
+def _make_transform_specs(transforms: Optional[Sequence[DecoderNativeTransform]]) -> str:
+    if transforms is None:
+        return ""
+
+    transforms = _convert_to_decoder_native_transforms(transforms)
+    return ";".join([t.make_params() for t in transforms])
 
 
 def _read_custom_frame_mappings(
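
The torchcodec.transforms module itself is not part of this commit; only its import is visible above. From the calls the diff makes (Resize(size=...), t.make_params()) and the "resize, height, width" spec strings in the old code, its shape can be guessed at. The following is an illustrative sketch under those assumptions, not the committed code:

    from dataclasses import dataclass
    from typing import Sequence


    @dataclass
    class DecoderNativeTransform:
        # Base class for transforms applied by the decoder itself (sketch).
        def make_params(self) -> str:
            raise NotImplementedError


    @dataclass
    class Resize(DecoderNativeTransform):
        size: Sequence[int]  # (height, width)

        def make_params(self) -> str:
            # Same "resize, height, width" spec format that the old
            # make_transform_specs built inline.
            return f"resize, {self.size[0]}, {self.size[1]}"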

test/test_transform_ops.py

Lines changed: 60 additions & 62 deletions
@@ -37,6 +37,66 @@
 
 
 class TestPublicVideoDecoderTransformOps:
+    @pytest.mark.parametrize(
+        "height_scaling_factor, width_scaling_factor",
+        ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
+    )
+    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
+    def test_resize_torchvision(
+        self, video, height_scaling_factor, width_scaling_factor
+    ):
+        height = int(video.get_height() * height_scaling_factor)
+        width = int(video.get_width() * width_scaling_factor)
+
+        decoder_resize = VideoDecoder(
+            video.path, transforms=[v2.Resize(size=(height, width))]
+        )
+
+        decoder_full = VideoDecoder(video.path)
+
+        num_frames = len(decoder_resize)
+        assert num_frames == len(decoder_full)
+
+        for frame_index in [
+            0,
+            int(num_frames * 0.1),
+            int(num_frames * 0.2),
+            int(num_frames * 0.3),
+            int(num_frames * 0.4),
+            int(num_frames * 0.5),
+            int(num_frames * 0.75),
+            int(num_frames * 0.90),
+            num_frames - 1,
+        ]:
+            expected_shape = (video.get_num_color_channels(), height, width)
+            frame_resize = decoder_resize[frame_index]
+            frame_full = decoder_full[frame_index]
+
+            frame_tv = v2.functional.resize(frame_full, size=(height, width))
+            frame_tv_no_antialias = v2.functional.resize(
+                frame_full, size=(height, width), antialias=False
+            )
+
+            assert frame_resize.shape == expected_shape
+            assert frame_tv.shape == expected_shape
+            assert frame_tv_no_antialias.shape == expected_shape
+
+            assert_tensor_close_on_at_least(
+                frame_resize, frame_tv, percentage=99.8, atol=1
+            )
+            torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)
+
+            if height_scaling_factor < 1 or width_scaling_factor < 1:
+                # Antialias only relevant when down-scaling!
+                with pytest.raises(AssertionError, match="Expected at least"):
+                    assert_tensor_close_on_at_least(
+                        frame_resize, frame_tv_no_antialias, percentage=99, atol=1
+                    )
+                with pytest.raises(AssertionError, match="Tensor-likes are not close"):
+                    torch.testing.assert_close(
+                        frame_resize, frame_tv_no_antialias, rtol=0, atol=6
+                    )
+
     def test_resize_fails(self):
         with pytest.raises(
             ValueError,
@@ -187,68 +247,6 @@ def test_transform_fails(self):
         ):
             add_video_stream(decoder, transform_specs="invalid, 1, 2")
 
-    @pytest.mark.parametrize(
-        "height_scaling_factor, width_scaling_factor",
-        ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
-    )
-    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
-    def test_resize_torchvision(
-        self, video, height_scaling_factor, width_scaling_factor
-    ):
-        num_frames = self.get_num_frames_core_ops(video)
-
-        height = int(video.get_height() * height_scaling_factor)
-        width = int(video.get_width() * width_scaling_factor)
-        resize_spec = f"resize, {height}, {width}"
-
-        decoder_resize = create_from_file(str(video.path))
-        add_video_stream(decoder_resize, transform_specs=resize_spec)
-
-        decoder_full = create_from_file(str(video.path))
-        add_video_stream(decoder_full)
-
-        for frame_index in [
-            0,
-            int(num_frames * 0.1),
-            int(num_frames * 0.2),
-            int(num_frames * 0.3),
-            int(num_frames * 0.4),
-            int(num_frames * 0.5),
-            int(num_frames * 0.75),
-            int(num_frames * 0.90),
-            num_frames - 1,
-        ]:
-            expected_shape = (video.get_num_color_channels(), height, width)
-            frame_resize, *_ = get_frame_at_index(
-                decoder_resize, frame_index=frame_index
-            )
-
-            frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index)
-            frame_tv = v2.functional.resize(frame_full, size=(height, width))
-            frame_tv_no_antialias = v2.functional.resize(
-                frame_full, size=(height, width), antialias=False
-            )
-
-            assert frame_resize.shape == expected_shape
-            assert frame_tv.shape == expected_shape
-            assert frame_tv_no_antialias.shape == expected_shape
-
-            assert_tensor_close_on_at_least(
-                frame_resize, frame_tv, percentage=99.8, atol=1
-            )
-            torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)
-
-            if height_scaling_factor < 1 or width_scaling_factor < 1:
-                # Antialias only relevant when down-scaling!
-                with pytest.raises(AssertionError, match="Expected at least"):
-                    assert_tensor_close_on_at_least(
-                        frame_resize, frame_tv_no_antialias, percentage=99, atol=1
-                    )
-                with pytest.raises(AssertionError, match="Tensor-likes are not close"):
-                    torch.testing.assert_close(
-                        frame_resize, frame_tv_no_antialias, rtol=0, atol=6
-                    )
-
     def test_resize_ffmpeg(self):
         height = 135
         width = 240