
Commit 8e6a8f2

Way more defensive programming
1 parent 7e43313 commit 8e6a8f2

File tree

src/torchcodec/decoders/_video_decoder.py
src/torchcodec/transforms/_decoder_transforms.py
test/test_transform_ops.py

3 files changed: 152 additions & 80 deletions


src/torchcodec/decoders/_video_decoder.py

Lines changed: 18 additions & 5 deletions
@@ -168,7 +168,9 @@ def __init__(
 
         device_variant = _get_cuda_backend()
         transform_specs = _make_transform_specs(
-            transforms, input_dims=(self.metadata.height, self.metadata.width)
+            transforms,
+            input_dims=(self.metadata.height, self.metadata.width),
+            dimension_order=dimension_order,
         )
 
         core.add_video_stream(
@@ -452,7 +454,8 @@ def _get_and_validate_stream_metadata(
 
 def _convert_to_decoder_transforms(
     transforms: Sequence[Union[DecoderTransform, nn.Module]],
-    input_dims: Tuple[int, int],
+    input_dims: Tuple[Optional[int], Optional[int]],
+    dimension_order: Literal["NCHW", "NHWC"],
 ) -> List[DecoderTransform]:
     """Convert a sequence of transforms that may contain TorchVision transform
     objects into a list of only TorchCodec transform objects.
@@ -489,7 +492,16 @@ def _convert_to_decoder_transforms(
             input_dims = transform_tc._get_output_dims(input_dims)
             converted_transforms.append(transform_tc)
         elif isinstance(transform, v2.RandomCrop):
-            transform_tc = RandomCrop._from_torchvision(transform, input_dims)
+            if dimension_order != "NCHW":
+                raise ValueError(
+                    "TorchVision v2 RandomCrop is only supported for NCHW "
+                    "dimension order. Please use the TorchCodec RandomCrop "
+                    "transform instead."
+                )
+            transform_tc = RandomCrop._from_torchvision(
+                transform,
+                input_dims,
+            )
             input_dims = transform_tc._get_output_dims(input_dims)
             converted_transforms.append(transform_tc)
         else:
@@ -507,7 +519,8 @@ def _convert_to_decoder_transforms(
 
 def _make_transform_specs(
     transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-    input_dims: Tuple[int, int],
+    input_dims: Tuple[Optional[int], Optional[int]],
+    dimension_order: Literal["NCHW", "NHWC"],
 ) -> str:
     """Given a sequence of transforms, turn those into the specification string
     the core API expects.
@@ -527,7 +540,7 @@ def _make_transform_specs(
     if transforms is None:
         return ""
 
-    transforms = _convert_to_decoder_transforms(transforms, input_dims)
+    transforms = _convert_to_decoder_transforms(transforms, input_dims, dimension_order)
     return ";".join([t._make_transform_spec() for t in transforms])
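In practice, the new plumbing means the transform spec string is built from both the stream metadata and the decoder's dimension_order. A minimal sketch of the contract (it calls the private helpers directly for illustration only, and assumes Resize is exported from torchcodec.transforms the way RandomCrop is in the tests below):

from torchcodec.decoders._video_decoder import _make_transform_specs
from torchcodec.transforms import Resize  # assumed export

# Each DecoderTransform renders itself as "name, arg1, arg2, ..."; the specs
# are joined with ";" before being handed to core.add_video_stream().
spec = _make_transform_specs(
    [Resize(size=(270, 480))],
    input_dims=(720, 1280),  # (metadata.height, metadata.width); either may be None
    dimension_order="NCHW",  # the newly threaded-through argument
)
assert spec == "resize, 270, 480"

# With no transforms, the decoder passes an empty spec string to the core API.
assert _make_transform_specs(None, input_dims=(720, 1280), dimension_order="NCHW") == ""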

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 79 additions & 32 deletions
@@ -41,7 +41,9 @@ class DecoderTransform(ABC):
     def _make_transform_spec(self) -> str:
         pass
 
-    def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
+    def _get_output_dims(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> Tuple[Optional[int], Optional[int]]:
         return input_dims
 
 
@@ -70,34 +72,39 @@ class Resize(DecoderTransform):
     size: Sequence[int]
 
     def _make_transform_spec(self) -> str:
+        # TODO: establish this invariant in the constructor during refactor
         assert len(self.size) == 2
         return f"resize, {self.size[0]}, {self.size[1]}"
 
-    def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
-        return (*self.size,)
+    def _get_output_dims(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> Tuple[Optional[int], Optional[int]]:
+        # TODO: establish this invariant in the constructor during refactor
+        assert len(self.size) == 2
+        return (self.size[0], self.size[1])
 
     @classmethod
-    def _from_torchvision(cls, resize_tv: nn.Module):
+    def _from_torchvision(cls, tv_resize: nn.Module):
         v2 = import_torchvision_transforms_v2()
 
-        assert isinstance(resize_tv, v2.Resize)
+        assert isinstance(tv_resize, v2.Resize)
 
-        if resize_tv.interpolation is not v2.InterpolationMode.BILINEAR:
+        if tv_resize.interpolation is not v2.InterpolationMode.BILINEAR:
             raise ValueError(
                 "TorchVision Resize transform must use bilinear interpolation."
             )
-        if resize_tv.antialias is False:
+        if tv_resize.antialias is False:
             raise ValueError(
                 "TorchVision Resize transform must have antialias enabled."
             )
-        if resize_tv.size is None:
+        if tv_resize.size is None:
             raise ValueError("TorchVision Resize transform must have a size specified.")
-        if len(resize_tv.size) != 2:
+        if len(tv_resize.size) != 2:
             raise ValueError(
                 "TorchVision Resize transform must have a (height, width) "
-                f"pair for the size, got {resize_tv.size}."
+                f"pair for the size, got {tv_resize.size}."
             )
-        return cls(size=resize_tv.size)
+        return cls(size=tv_resize.size)
 
 
 @dataclass
@@ -140,52 +147,92 @@ def _make_transform_spec(self) -> str:
         )
         if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]:
             raise ValueError(
-                f"Input dimensions {input_dims} are smaller than the crop size {self.size}."
+                f"Input dimensions {self._input_dims} are smaller than the crop size {self.size}."
             )
 
         # Note: This logic must match the logic in
         # torchvision.transforms.v2.RandomCrop.make_params(). Given
         # the same seed, they should get the same result. This is an
         # API guarantee with our users.
-        self._top = torch.randint(
-            0, self._input_dims[0] - self.size[0] + 1, size=()
+        self._top = int(
+            torch.randint(0, self._input_dims[0] - self.size[0] + 1, size=()).item()
         )
-        self._left = torch.randint(
-            0, self._input_dims[1] - self.size[1] + 1, size=()
+        self._left = int(
+            torch.randint(0, self._input_dims[1] - self.size[1] + 1, size=()).item()
        )
 
         return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}"
 
-    def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
-        self._input_dims = input_dims
-        return self.size
+    def _get_output_dims(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> Tuple[Optional[int], Optional[int]]:
+        # TODO: establish this invariant in the constructor during refactor
+        assert len(self.size) == 2
+
+        height, width = input_dims
+        if height is None:
+            raise ValueError(
+                "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known."
+            )
+        if width is None:
+            raise ValueError(
+                "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known."
+            )
+
+        self._input_dims = (height, width)
+        return (self.size[0], self.size[1])
 
     @classmethod
-    def _from_torchvision(cls, random_crop_tv: nn.Module, input_dims: Tuple[int, int]):
+    def _from_torchvision(
+        cls,
+        tv_random_crop: nn.Module,
+        input_dims: Tuple[Optional[int], Optional[int]],
+    ):
         v2 = import_torchvision_transforms_v2()
 
-        assert isinstance(random_crop_tv, v2.RandomCrop)
+        assert isinstance(tv_random_crop, v2.RandomCrop)
 
-        if random_crop_tv.padding is not None:
+        if tv_random_crop.padding is not None:
             raise ValueError(
                 "TorchVision RandomCrop transform must not specify padding."
             )
-        if random_crop_tv.pad_if_needed is True:
+
+        if tv_random_crop.pad_if_needed is True:
             raise ValueError(
                 "TorchVision RandomCrop transform must not specify pad_if_needed."
             )
-        if random_crop_tv.fill != 0:
+
+        if tv_random_crop.fill != 0:
             raise ValueError("TorchVision RandomCrop fill must be 0.")
-        if random_crop_tv.padding_mode != "constant":
+
+        if tv_random_crop.padding_mode != "constant":
             raise ValueError("TorchVision RandomCrop padding_mode must be constant.")
-        if len(random_crop_tv.size) != 2:
+
+        if len(tv_random_crop.size) != 2:
             raise ValueError(
                 "TorchVision RandomCrop transform must have a (height, width) "
-                f"pair for the size, got {random_crop_tv.size}."
+                f"pair for the size, got {tv_random_crop.size}."
+            )
+
+        height, width = input_dims
+        if height is None:
+            raise ValueError(
+                "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known."
+            )
+        if width is None:
+            raise ValueError(
+                "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known."
            )
-        params = random_crop_tv.make_params(
-            # TODO: deal with NCHW versus NHWC; video decoder knows
-            torch.empty(size=(3, *input_dims), dtype=torch.uint8)
+
+        # Note that TorchVision v2 transforms only accept NCHW tensors.
+        params = tv_random_crop.make_params(
+            torch.empty(size=(3, height, width), dtype=torch.uint8)
         )
-        assert random_crop_tv.size == (params["height"], params["width"])
-        return cls(size=random_crop_tv.size, _top=params["top"], _left=params["left"])
+
+        if tv_random_crop.size != (params["height"], params["width"]):
+            raise ValueError(
+                f"TorchVision RandomCrop's provided size, {tv_random_crop.size} "
+                f"must match the computed size, {params['height'], params['width']}."
+            )
+
+        return cls(size=tv_random_crop.size, _top=params["top"], _left=params["left"])
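The comment preserved in the _make_transform_spec hunk promises that, given the same seed, TorchCodec draws the same crop window as torchvision.transforms.v2.RandomCrop.make_params(). A rough sketch of that parity check, mirroring the dummy-NCHW-tensor make_params call used in _from_torchvision above (the dimensions here are arbitrary):

import torch
from torchvision.transforms import v2

height, width = 720, 1280
size = (100, 100)

# Offsets as RandomCrop._make_transform_spec() draws them: top first, then left.
torch.manual_seed(0)
top = int(torch.randint(0, height - size[0] + 1, size=()).item())
left = int(torch.randint(0, width - size[1] + 1, size=()).item())

# Offsets as the TorchVision conversion path computes them.
torch.manual_seed(0)
params = v2.RandomCrop(size=size).make_params(
    torch.empty(size=(3, height, width), dtype=torch.uint8)
)

# With the same seed, both paths should land on the same crop window.
assert (params["top"], params["left"]) == (top, left)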

test/test_transform_ops.py

Lines changed: 55 additions & 43 deletions
@@ -151,7 +151,10 @@ def test_resize_fails(self):
         )
     @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
     def test_random_crop_torchvision(
-        self, video, height_scaling_factor, width_scaling_factor
+        self,
+        height_scaling_factor,
+        width_scaling_factor,
+        video,
     ):
         height = int(video.get_height() * height_scaling_factor)
         width = int(video.get_width() * width_scaling_factor)
@@ -165,7 +168,8 @@ def test_random_crop_torchvision(
 
         torch.manual_seed(0)
         decoder_random_crop_tv = VideoDecoder(
-            video.path, transforms=[v2.RandomCrop(size=(height, width))]
+            video.path,
+            transforms=[v2.RandomCrop(size=(height, width))],
         )
 
         decoder_full = VideoDecoder(video.path)
@@ -201,61 +205,69 @@ def test_random_crop_torchvision(
         )
         assert_frames_equal(frame_random_crop, frame_tv)
 
-    def test_crop_fails(self):
-        with pytest.raises(
-            ValueError,
-            match="must not specify padding",
-        ):
-            VideoDecoder(
-                NASA_VIDEO.path,
-                transforms=[
-                    v2.RandomCrop(
-                        size=(100, 100),
-                        padding=255,
-                    )
-                ],
-            )
+    @pytest.mark.parametrize(
+        "height_scaling_factor, width_scaling_factor",
+        ((0.25, 0.1), (0.25, 0.25)),
+    )
+    def test_random_crop_nhwc(
+        self,
+        height_scaling_factor,
+        width_scaling_factor,
+    ):
+        height = int(TEST_SRC_2_720P.get_height() * height_scaling_factor)
+        width = int(TEST_SRC_2_720P.get_width() * width_scaling_factor)
 
-        with pytest.raises(
-            ValueError,
-            match="must not specify pad_if_needed",
-        ):
-            VideoDecoder(
-                NASA_VIDEO.path,
-                transforms=[
-                    v2.RandomCrop(
-                        size=(100, 100),
-                        pad_if_needed=True,
-                    )
-                ],
-            )
+        decoder = VideoDecoder(
+            TEST_SRC_2_720P.path,
+            transforms=[torchcodec.transforms.RandomCrop(size=(height, width))],
+            dimension_order="NHWC",
+        )
+
+        num_frames = len(decoder)
+        for frame_index in [
+            0,
+            int(num_frames * 0.25),
+            int(num_frames * 0.5),
+            int(num_frames * 0.75),
+            num_frames - 1,
+        ]:
+            frame = decoder[frame_index]
+            assert frame.shape == (height, width, 3)
 
+    @pytest.mark.parametrize(
+        "error_message, params",
+        (
+            ("must not specify padding", dict(size=(100, 100), padding=255)),
+            (
+                "must not specify pad_if_needed",
+                dict(size=(100, 100), pad_if_needed=True),
+            ),
+            ("fill must be 0", dict(size=(100, 100), fill=255)),
+            (
+                "padding_mode must be constant",
+                dict(size=(100, 100), padding_mode="edge"),
+            ),
+        ),
+    )
+    def test_crop_fails(self, error_message, params):
         with pytest.raises(
             ValueError,
-            match="fill must be 0",
+            match=error_message,
         ):
             VideoDecoder(
                 NASA_VIDEO.path,
-                transforms=[
-                    v2.RandomCrop(
-                        size=(100, 100),
-                        fill=255,
-                    )
-                ],
+                transforms=[v2.RandomCrop(**params)],
             )
 
+    def test_tv_random_crop_nhwc_fails(self):
         with pytest.raises(
             ValueError,
-            match="padding_mode must be constant",
+            match="TorchVision v2 RandomCrop is only supported for NCHW",
         ):
             VideoDecoder(
                 NASA_VIDEO.path,
-                transforms=[
-                    v2.RandomCrop(
-                        size=(100, 100),
-                        padding_mode="edge",
-                    )
-                ],
+                transforms=[v2.RandomCrop(size=(100, 100))],
+                dimension_order="NHWC",
             )
 
     def test_transform_fails(self):
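Taken together, the new tests pin down the user-facing behavior. A condensed sketch (the video path is a placeholder):

from torchvision.transforms import v2

import torchcodec
from torchcodec.decoders import VideoDecoder

# The TorchCodec-native RandomCrop works with NHWC output...
decoder = VideoDecoder(
    "video.mp4",  # placeholder path
    transforms=[torchcodec.transforms.RandomCrop(size=(100, 100))],
    dimension_order="NHWC",
)
assert decoder[0].shape == (100, 100, 3)  # (H, W, C)

# ...but a TorchVision v2 RandomCrop is rejected under NHWC, since v2
# transforms only operate on NCHW tensors.
try:
    VideoDecoder(
        "video.mp4",
        transforms=[v2.RandomCrop(size=(100, 100))],
        dimension_order="NHWC",
    )
except ValueError as err:
    print(err)  # TorchVision v2 RandomCrop is only supported for NCHW ...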
