
Commit d8b7ed0 (1 parent: 8e6a8f2)

Refactor all the things

3 files changed: +99 -136 lines

src/torchcodec/decoders/_video_decoder.py

Lines changed: 59 additions & 59 deletions
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import List, Literal, Optional, Sequence, Tuple, Union
+from typing import Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import device as torch_device, nn, Tensor
@@ -170,7 +170,6 @@ def __init__(
         transform_specs = _make_transform_specs(
             transforms,
             input_dims=(self.metadata.height, self.metadata.width),
-            dimension_order=dimension_order,
         )
 
         core.add_video_stream(
@@ -452,96 +451,97 @@ def _get_and_validate_stream_metadata(
     )
 
 
-def _convert_to_decoder_transforms(
-    transforms: Sequence[Union[DecoderTransform, nn.Module]],
+def _make_transform_specs(
+    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
     input_dims: Tuple[Optional[int], Optional[int]],
-    dimension_order: Literal["NCHW", "NHWC"],
-) -> List[DecoderTransform]:
-    """Convert a sequence of transforms that may contain TorchVision transform
-    objects into a list of only TorchCodec transform objects.
+) -> str:
+    """Given a sequence of transforms, turn those into the specification string
+    the core API expects.
 
     Args:
-        transforms: Squence of transform objects. The objects can be one of two
-            types:
+        transforms: Optional sequence of transform objects. The objects can be
+            one of two types:
            1. torchcodec.transforms.DecoderTransform
            2. torchvision.transforms.v2.Transform, but our type annotation
               only mentions its base, nn.Module. We don't want to take a
               hard dependency on TorchVision.
+        input_dims: Optional (height, width) pair. Note that only some
+            transforms need to know the dimensions. If the user provides
+            transforms that don't need to know the dimensions, and that metadata
+            is missing, everything should still work. That means we assert their
+            existence as late as possible.
 
     Returns:
-        List of DecoderTransform objects.
+        String of transforms in the format the core API expects: transform
+        specifications separate by semicolons.
     """
+    if transforms is None:
+        return ""
+
     try:
         from torchvision.transforms import v2
 
         tv_available = True
     except ImportError:
         tv_available = False
 
-    converted_transforms: list[DecoderTransform] = []
+    # The following loop accomplishes two tasks:
+    #
+    # 1. Converts the transform to a DecoderTransform, if necessary. We
+    #    accept TorchVision transform objects and they must be converted
+    #    to their matching DecoderTransform.
+    # 2. Calculates what the input dimensions are to each transform.
+    #
+    # The order in our transforms list is semantically meaningful, as we
+    # actually have a pipeline where the output of one transform is the input to
+    # the next. For example, if we have the transforms list [A, B, C, D], then
+    # we should understand that as:
+    #    A -> B -> C -> D
+    # Where the frame produced by A is the input to B, the frame produced by B
+    # is the input to C, etc. This particularly matters for frame dimensions.
+    # Transforms can both:
+    #
+    # 1. Produce frames with arbitrary dimensions.
+    # 2. Rely on their input frame's dimensions to calculate ahead-of-time
+    #    what their runtime behavior will be.
+    #
+    # The consequence of the above facts is that we need to statically track
+    # frame dimensions in the pipeline while we pre-process it. The input
+    # frame's dimensions to A, our first transform, is always what we know from
+    # our metadata. For each transform, we always calculate its output
+    # dimensions from its input dimensions. We store these with the converted
+    # transform, to be all used together when we generate the specs.
+    converted_transforms: list[(DecoderTransform, Tuple[int, int])] = []
+    curr_input_dims = input_dims
     for transform in transforms:
-        if not isinstance(transform, DecoderTransform):
+        if isinstance(transform, DecoderTransform):
+            output_dims = transform._calculate_output_dims(curr_input_dims)
+            converted_transforms.append((transform, curr_input_dims))
+        else:
            if not tv_available:
                raise ValueError(
                    f"The supplied transform, {transform}, is not a TorchCodec "
-                    " DecoderTransform. TorchCodec also accept TorchVision "
+                    " DecoderTransform. TorchCodec also accepts TorchVision "
                    "v2 transforms, but TorchVision is not installed."
                )
            elif isinstance(transform, v2.Resize):
-                transform_tc = Resize._from_torchvision(transform)
-                input_dims = transform_tc._get_output_dims(input_dims)
-                converted_transforms.append(transform_tc)
+                tc_transform = Resize._from_torchvision(transform)
+                output_dims = tc_transform._calculate_output_dims(curr_input_dims)
+                converted_transforms.append((tc_transform, curr_input_dims))
            elif isinstance(transform, v2.RandomCrop):
-                if dimension_order != "NCHW":
-                    raise ValueError(
-                        "TorchVision v2 RandomCrop is only supported for NCHW "
-                        "dimension order. Please use the TorchCodec RandomCrop "
-                        "transform instead."
-                    )
-                transform_tc = RandomCrop._from_torchvision(
-                    transform,
-                    input_dims,
-                )
-                input_dims = transform_tc._get_output_dims(input_dims)
-                converted_transforms.append(transform_tc)
+                tc_transform = RandomCrop._from_torchvision(transform)
+                output_dims = tc_transform._calculate_output_dims(curr_input_dims)
+                converted_transforms.append((tc_transform, curr_input_dims))
            else:
                raise ValueError(
                    f"Unsupported transform: {transform}. Transforms must be "
                    "either a TorchCodec DecoderTransform or a TorchVision "
                    "v2 transform."
                )
-        else:
-            input_dims = transform._get_output_dims(input_dims)
-            converted_transforms.append(transform)
-
-    return converted_transforms
-
 
-def _make_transform_specs(
-    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-    input_dims: Tuple[Optional[int], Optional[int]],
-    dimension_order: Literal["NCHW", "NHWC"],
-) -> str:
-    """Given a sequence of transforms, turn those into the specification string
-    the core API expects.
-
-    Args:
-        transforms: Optional sequence of transform objects. The objects can be
-            one of two types:
-            1. torchcodec.transforms.DecoderTransform
-            2. torchvision.transforms.v2.Transform, but our type annotation
-               only mentions its base, nn.Module. We don't want to take a
-               hard dependency on TorchVision.
-
-    Returns:
-        String of transforms in the format the core API expects: transform
-        specifications separate by semicolons.
-    """
-    if transforms is None:
-        return ""
+        curr_input_dims = output_dims
 
-    transforms = _convert_to_decoder_transforms(transforms, input_dims, dimension_order)
-    return ";".join([t._make_transform_spec() for t in transforms])
+    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
 
 
 def _read_custom_frame_mappings(
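
For context on the loop above: the refactored `_make_transform_specs` pairs each transform with the dimensions of the frame it will receive, threads the output dimensions of one transform into the next, and joins the per-transform specs with semicolons. Below is a minimal, self-contained sketch of that bookkeeping; `FakeResize`, `FakeCrop`, and `make_specs` are illustrative stand-ins, not TorchCodec APIs, and the crop offsets here are computed deterministically rather than sampled.

```python
from typing import Tuple


class FakeResize:
    def __init__(self, size: Tuple[int, int]):
        self.size = size

    def calculate_output_dims(self, input_dims):
        # A resize always produces its target size, regardless of input dims.
        return self.size

    def make_spec(self, input_dims) -> str:
        return f"resize, {self.size[0]}, {self.size[1]}"


class FakeCrop:
    def __init__(self, size: Tuple[int, int]):
        self.size = size

    def calculate_output_dims(self, input_dims):
        return self.size

    def make_spec(self, input_dims) -> str:
        # A crop needs its input dims to pick valid offsets. Placeholder logic:
        # anchor at the bottom-right corner (the real RandomCrop samples randomly).
        top = input_dims[0] - self.size[0]
        left = input_dims[1] - self.size[1]
        return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"


def make_specs(transforms, input_dims) -> str:
    paired = []
    curr = input_dims
    for t in transforms:
        paired.append((t, curr))              # store each transform with ITS input dims
        curr = t.calculate_output_dims(curr)  # which become the next transform's input
    return ";".join(t.make_spec(dims) for t, dims in paired)


# (270, 480) frame -> resize to (540, 960) -> crop back down to (100, 100)
print(make_specs([FakeResize((540, 960)), FakeCrop((100, 100))], (270, 480)))
# resize, 540, 960;crop, 100, 100, 860, 440
```

The point of storing `(transform, input_dims)` pairs is that the crop sees the post-resize dimensions (540, 960), not the metadata dimensions (270, 480), when it builds its spec.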

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 34 additions & 58 deletions
@@ -38,10 +38,10 @@ class DecoderTransform(ABC):
     """
 
     @abstractmethod
-    def _make_transform_spec(self) -> str:
+    def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
        pass
 
-    def _get_output_dims(
+    def _calculate_output_dims(
        self, input_dims: Tuple[Optional[int], Optional[int]]
    ) -> Tuple[Optional[int], Optional[int]]:
        return input_dims
@@ -71,12 +71,12 @@ class Resize(DecoderTransform):
 
     size: Sequence[int]
 
-    def _make_transform_spec(self) -> str:
+    def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
        # TODO: establish this invariant in the constructor during refactor
        assert len(self.size) == 2
        return f"resize, {self.size[0]}, {self.size[1]}"
 
-    def _get_output_dims(
+    def _calculate_output_dims(
        self, input_dims: Tuple[Optional[int], Optional[int]]
    ) -> Tuple[Optional[int], Optional[int]]:
        # TODO: establish this invariant in the constructor during refactor
@@ -125,45 +125,37 @@ class RandomCrop(DecoderTransform):
     """
 
     size: Sequence[int]
+
+    # Note that these values are never read by this object or the decoder. We
+    # record them for testing purposes only.
     _top: Optional[int] = None
     _left: Optional[int] = None
-    _input_dims: Optional[Tuple[int, int]] = None
 
-    def _make_transform_spec(self) -> str:
+    def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
        if len(self.size) != 2:
            raise ValueError(
                f"RandomCrop's size must be a sequence of length 2, got {self.size}. "
                "This should never happen, please report a bug."
            )
 
-        if self._top is None or self._left is None:
-            # TODO: It would be very strange if only ONE of those is None. But should we
-            # make it an error? We can continue, but it would probably mean
-            # something bad happened. Dear reviewer, please register an opinion here:
-            if self._input_dims is None:
-                raise ValueError(
-                    "RandomCrop's input_dims must be set before calling _make_transform_spec(). "
-                    "This should never happen, please report a bug."
-                )
-            if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]:
-                raise ValueError(
-                    f"Input dimensions {self._input_dims} are smaller than the crop size {self.size}."
-                )
-
-            # Note: This logic must match the logic in
-            # torchvision.transforms.v2.RandomCrop.make_params(). Given
-            # the same seed, they should get the same result. This is an
-            # API guarantee with our users.
-            self._top = int(
-                torch.randint(0, self._input_dims[0] - self.size[0] + 1, size=()).item()
-            )
-            self._left = int(
-                torch.randint(0, self._input_dims[1] - self.size[1] + 1, size=()).item()
+        # Note: This logic below must match the logic in
+        # torchvision.transforms.v2.RandomCrop.make_params(). Given
+        # the same seed, they should get the same result. This is an
+        # API guarantee with our users.
+        if input_dims[0] < self.size[0] or input_dims[1] < self.size[1]:
+            raise ValueError(
+                f"Input dimensions {input_dims} are smaller than the crop size {self.size}."
            )
 
-        return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}"
+        top = int(torch.randint(0, input_dims[0] - self.size[0] + 1, size=()).item())
+        self._top = top
+
+        left = int(torch.randint(0, input_dims[1] - self.size[1] + 1, size=()).item())
+        self._left = left
+
+        return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"
 
-    def _get_output_dims(
+    def _calculate_output_dims(
        self, input_dims: Tuple[Optional[int], Optional[int]]
    ) -> Tuple[Optional[int], Optional[int]]:
        # TODO: establish this invariant in the constructor during refactor
@@ -172,25 +164,30 @@ def _get_output_dims(
         height, width = input_dims
         if height is None:
            raise ValueError(
-                "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known."
+                "Video metadata has no height. "
+                "RandomCrop can only be used when input frame dimensions are known."
            )
        if width is None:
            raise ValueError(
-                "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known."
+                "Video metadata has no width. "
+                "RandomCrop can only be used when input frame dimensions are known."
            )
 
-        self._input_dims = (height, width)
        return (self.size[0], self.size[1])
 
    @classmethod
    def _from_torchvision(
        cls,
        tv_random_crop: nn.Module,
-        input_dims: Tuple[Optional[int], Optional[int]],
    ):
        v2 = import_torchvision_transforms_v2()
 
-        assert isinstance(tv_random_crop, v2.RandomCrop)
+        if not isinstance(tv_random_crop, v2.RandomCrop):
+            raise ValueError(
+                "Transform must be TorchVision's RandomCrop, "
+                f"it is instead {type(tv_random_crop).__name__}. "
+                "This should never happen, please report a bug."
+            )
 
        if tv_random_crop.padding is not None:
            raise ValueError(
@@ -214,25 +211,4 @@ def _from_torchvision(
                 f"pair for the size, got {tv_random_crop.size}."
             )
 
-        height, width = input_dims
-        if height is None:
-            raise ValueError(
-                "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known."
-            )
-        if width is None:
-            raise ValueError(
-                "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known."
-            )
-
-        # Note that TorchVision v2 transforms only accept NCHW tensors.
-        params = tv_random_crop.make_params(
-            torch.empty(size=(3, height, width), dtype=torch.uint8)
-        )
-
-        if tv_random_crop.size != (params["height"], params["width"]):
-            raise ValueError(
-                f"TorchVision RandomCrop's provided size, {tv_random_crop.size} "
-                f"must match the computed size, {params['height'], params['width']}."
-            )
-
-        return cls(size=tv_random_crop.size, _top=params["top"], _left=params["left"])
+        return cls(size=tv_random_crop.size)
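
To illustrate the new `RandomCrop._make_transform_spec` flow above: crop offsets are now sampled from the given `input_dims` at spec-build time, rather than precomputed in `_from_torchvision`, so re-seeding torch reproduces the same offsets. A rough sketch under that assumption; the sizes and seed are placeholders, and the spec string format is taken from the diff.

```python
import torch

size = (100, 100)        # requested crop size
input_dims = (270, 480)  # frame (height, width), known from metadata

# Same sampling shown in the diff: uniform over the valid offset range.
torch.manual_seed(0)
top = int(torch.randint(0, input_dims[0] - size[0] + 1, size=()).item())
left = int(torch.randint(0, input_dims[1] - size[1] + 1, size=()).item())
spec = f"crop, {size[0]}, {size[1]}, {left}, {top}"
print(spec)  # "crop, 100, 100, <left>, <top>" -- the offsets depend on the seed

# Re-seeding reproduces the first draw, which is what lets the decoder-side
# RandomCrop line up with torchvision.transforms.v2.RandomCrop under the same seed.
torch.manual_seed(0)
assert top == int(torch.randint(0, input_dims[0] - size[0] + 1, size=()).item())
```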

test/test_transform_ops.py

Lines changed: 6 additions & 19 deletions
@@ -147,26 +147,28 @@ def test_resize_fails(self):
 
     @pytest.mark.parametrize(
         "height_scaling_factor, width_scaling_factor",
-        ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.25, 0.25)),
+        ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)),
    )
    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
+    @pytest.mark.parametrize("seed", [0, 1234])
    def test_random_crop_torchvision(
        self,
        height_scaling_factor,
        width_scaling_factor,
        video,
+        seed,
    ):
        height = int(video.get_height() * height_scaling_factor)
        width = int(video.get_width() * width_scaling_factor)
 
        # We want both kinds of RandomCrop objects to get arrive at the same
        # locations to crop, so we need to make sure they get the same random
        # seed.
-        torch.manual_seed(0)
+        torch.manual_seed(seed)
        tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width))
        decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop])
 
-        torch.manual_seed(0)
+        torch.manual_seed(seed)
        decoder_random_crop_tv = VideoDecoder(
            video.path,
            transforms=[v2.RandomCrop(size=(height, width))],
@@ -179,13 +181,9 @@ def test_random_crop_torchvision(
 
         for frame_index in [
             0,
-            int(num_frames * 0.1),
-            int(num_frames * 0.2),
-            int(num_frames * 0.3),
-            int(num_frames * 0.4),
+            int(num_frames * 0.25),
            int(num_frames * 0.5),
            int(num_frames * 0.75),
-            int(num_frames * 0.90),
            num_frames - 1,
        ]:
            frame_random_crop = decoder_random_crop[frame_index]
@@ -259,17 +257,6 @@ def test_crop_fails(self, error_message, params):
                 transforms=[v2.RandomCrop(**params)],
             )
 
-    def test_tv_random_crop_nhwc_fails(self):
-        with pytest.raises(
-            ValueError,
-            match="TorchVision v2 RandomCrop is only supported for NCHW",
-        ):
-            VideoDecoder(
-                NASA_VIDEO.path,
-                transforms=[v2.RandomCrop(size=(100, 100))],
-                dimension_order="NHWC",
-            )
-
    def test_transform_fails(self):
        with pytest.raises(
            ValueError,
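
The updated test parametrizes over seeds to check that, for a given seed, decoding with TorchCodec's `RandomCrop` and with TorchVision's `v2.RandomCrop` crops every frame at the same location. A minimal usage sketch of that equivalence, assuming the public `torchcodec.decoders.VideoDecoder` import, a placeholder video path, and a plain `torch.equal` comparison (the actual test's assertion helper may differ):

```python
import torch
import torchcodec
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

video_path = "my_video.mp4"  # placeholder path

# Decode with TorchCodec's decoder-side RandomCrop.
torch.manual_seed(0)
tc_decoder = VideoDecoder(
    video_path, transforms=[torchcodec.transforms.RandomCrop(size=(100, 100))]
)

# Decode with TorchVision's v2.RandomCrop converted internally by the decoder.
torch.manual_seed(0)
tv_decoder = VideoDecoder(video_path, transforms=[v2.RandomCrop(size=(100, 100))])

# Same seed, same crop location, so the first frames should match exactly.
assert torch.equal(tc_decoder[0], tv_decoder[0])
```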
