Commit aa15765

It... works?
1 parent af2e1ab commit aa15765

4 files changed (+93, −10 lines)


src/torchcodec/decoders/_video_decoder.py

Lines changed: 15 additions & 4 deletions
@@ -19,7 +19,7 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import DecoderTransform, Resize
+from torchcodec.transforms import DecoderTransform, RandomCrop, Resize
 
 
 class VideoDecoder:
@@ -167,7 +167,9 @@ def __init__(
         device = str(device)
 
         device_variant = _get_cuda_backend()
-        transform_specs = _make_transform_specs(transforms)
+        transform_specs = _make_transform_specs(
+            transforms, input_dims=(self.metadata.height, self.metadata.width)
+        )
 
         core.add_video_stream(
             self._decoder,
@@ -450,6 +452,7 @@ def _get_and_validate_stream_metadata(
 
 def _convert_to_decoder_transforms(
     transforms: Sequence[Union[DecoderTransform, nn.Module]],
+    input_dims: Tuple[int, int],
 ) -> List[DecoderTransform]:
     """Convert a sequence of transforms that may contain TorchVision transform
     objects into a list of only TorchCodec transform objects.
@@ -482,21 +485,29 @@ def _convert_to_decoder_transforms(
                     "v2 transforms, but TorchVision is not installed."
                 )
             elif isinstance(transform, v2.Resize):
-                converted_transforms.append(Resize._from_torchvision(transform))
+                transform_tc = Resize._from_torchvision(transform)
+                input_dims = transform_tc._get_output_dims(input_dims)
+                converted_transforms.append(transform_tc)
+            elif isinstance(transform, v2.RandomCrop):
+                transform_tc = RandomCrop._from_torchvision(transform, input_dims)
+                input_dims = transform_tc._get_output_dims(input_dims)
+                converted_transforms.append(transform_tc)
             else:
                 raise ValueError(
                     f"Unsupported transform: {transform}. Transforms must be "
                     "either a TorchCodec DecoderTransform or a TorchVision "
                     "v2 transform."
                 )
         else:
+            input_dims = transform._get_output_dims(input_dims)
             converted_transforms.append(transform)
 
     return converted_transforms
 
 
 def _make_transform_specs(
     transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
+    input_dims: Tuple[int, int],
 ) -> str:
     """Given a sequence of transforms, turn those into the specification string
     the core API expects.
@@ -516,7 +527,7 @@ def _make_transform_specs(
     if transforms is None:
         return ""
 
-    transforms = _convert_to_decoder_transforms(transforms)
+    transforms = _convert_to_decoder_transforms(transforms, input_dims)
     return ";".join([t._make_transform_spec() for t in transforms])

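Note: threading input_dims through the conversion means each transform sees the output dimensions of the one before it. A minimal sketch of the intended flow, using the private helpers from this diff on a hypothetical 480x640 source (not part of the commit):

    from torchcodec.transforms import RandomCrop, Resize

    resize = Resize(size=(270, 480))
    crop = RandomCrop(size=(100, 200))

    # (height, width) comes from the stream metadata, as in __init__ above.
    dims = (480, 640)
    dims = resize._get_output_dims(dims)  # -> (270, 480)
    dims = crop._get_output_dims(dims)    # records the crop's input dims, -> (100, 200)

    # _make_transform_specs joins the per-transform specs with ";":
    spec = ";".join(t._make_transform_spec() for t in (resize, crop))
    # e.g. "resize, 270, 480;crop, 100, 200, 132, 85" (offsets are sampled randomly)
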
src/torchcodec/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from ._decoder_transforms import DecoderTransform, Resize  # noqa
+from ._decoder_transforms import DecoderTransform, RandomCrop, Resize  # noqa

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 34 additions & 5 deletions
@@ -7,8 +7,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from types import ModuleType
-from typing import Sequence
+from typing import Optional, Sequence, Tuple
 
+import torch
 from torch import nn
 
 
@@ -40,6 +41,9 @@ class DecoderTransform(ABC):
     def _make_transform_spec(self) -> str:
         pass
 
+    def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
+        return input_dims
+
 
 def import_torchvision_transforms_v2() -> ModuleType:
     try:
@@ -69,6 +73,9 @@ def _make_transform_spec(self) -> str:
         assert len(self.size) == 2
         return f"resize, {self.size[0]}, {self.size[1]}"
 
+    def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
+        return self.size
+
     @classmethod
     def _from_torchvision(cls, resize_tv: nn.Module):
         v2 = import_torchvision_transforms_v2()
@@ -92,19 +99,38 @@ def _from_torchvision(cls, resize_tv: nn.Module):
         )
         return cls(size=resize_tv.size)
 
+
 @dataclass
 class RandomCrop(DecoderTransform):
 
     size: Sequence[int]
     _top: Optional[int] = None
     _left: Optional[int] = None
+    _input_dims: Optional[Tuple[int, int]] = None
 
     def _make_transform_spec(self) -> str:
         assert len(self.size) == 2
-        return f"crop, {self.size[0]}, {self.size[1]}, {_left}, {_top}"
+        if self._top is None or self._left is None:
+            assert self._input_dims is not None
+            if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]:
+                raise ValueError(
+                    f"Input dimensions {self._input_dims} are smaller than the crop size {self.size}."
+                )
+            self._top = torch.randint(
+                0, self._input_dims[0] - self.size[0] + 1, size=()
+            )
+            self._left = torch.randint(
+                0, self._input_dims[1] - self.size[1] + 1, size=()
+            )
+
+        return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}"
+
+    def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
+        self._input_dims = input_dims
+        return self.size
 
     @classmethod
-    def _from_torchvision(cls, random_crop_tv: nn.Module):
+    def _from_torchvision(cls, random_crop_tv: nn.Module, input_dims: Tuple[int, int]):
         v2 = import_torchvision_transforms_v2()
 
         assert isinstance(random_crop_tv, v2.RandomCrop)
@@ -128,5 +154,8 @@ def _from_torchvision(cls, random_crop_tv: nn.Module):
                 "TorchVision RandomCrop transform must have a (height, width) "
                 f"pair for the size, got {random_crop_tv.size}."
             )
-        params = random_crop_tv.make_params([])
-        return cls(size=random_crop_tv.size)
+        params = random_crop_tv.make_params(
+            torch.empty(size=(3, *input_dims), dtype=torch.uint8)
+        )
+        assert random_crop_tv.size == (params["height"], params["width"])
+        return cls(size=random_crop_tv.size, _top=params["top"], _left=params["left"])

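Note: _from_torchvision delegates offset sampling to TorchVision itself by calling make_params on a dummy uint8 frame of the input dimensions, so the crop window matches what v2.RandomCrop would pick under the same RNG state. A sketch of that handshake, mirroring the call in the diff above (hypothetical 270x480 frame, not part of the commit):

    import torch
    from torchvision.transforms import v2

    tv_crop = v2.RandomCrop(size=(100, 200))

    torch.manual_seed(0)
    params = tv_crop.make_params(torch.empty(size=(3, 270, 480), dtype=torch.uint8))

    # params["top"] / params["left"] become RandomCrop._top / RandomCrop._left,
    # and the decoder-side spec renders as "crop, 100, 200, {left}, {top}".
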
test/test_transform_ops.py

Lines changed: 43 additions & 0 deletions
@@ -145,6 +145,49 @@ def test_resize_fails(self):
         ):
             VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])
 
+    @pytest.mark.parametrize(
+        "height_scaling_factor, width_scaling_factor",
+        ((0.5, 0.5), (0.25, 0.1)),
+    )
+    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
+    def test_random_crop_torchvision(
+        self, video, height_scaling_factor, width_scaling_factor
+    ):
+        height = int(video.get_height() * height_scaling_factor)
+        width = int(video.get_width() * width_scaling_factor)
+
+        torch.manual_seed(0)
+        tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width))
+        decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop])
+
+        torch.manual_seed(0)
+        decoder_random_crop_tv = VideoDecoder(
+            video.path, transforms=[v2.RandomCrop(size=(height, width))]
+        )
+
+        decoder_full = VideoDecoder(video.path)
+
+        num_frames = len(decoder_random_crop_tv)
+        assert num_frames == len(decoder_full)
+
+        for frame_index in [
+            0,
+            int(num_frames * 0.1),
+            int(num_frames * 0.2),
+            int(num_frames * 0.3),
+            int(num_frames * 0.4),
+            int(num_frames * 0.5),
+            int(num_frames * 0.75),
+            int(num_frames * 0.90),
+            num_frames - 1,
+        ]:
+            frame_random_crop = decoder_random_crop[frame_index]
+            frame_random_crop_tv = decoder_random_crop_tv[frame_index]
+            assert_frames_equal(frame_random_crop, frame_random_crop_tv)
+
+        expected_shape = (video.get_num_color_channels(), height, width)
+        assert frame_random_crop_tv.shape == expected_shape
+
     def test_transform_fails(self):
         with pytest.raises(
             ValueError,

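Note: the test seeds the global RNG before constructing each decoder because both crop paths draw their (top, left) offsets from torch's default generator during VideoDecoder construction; identical seeds therefore select identical crop windows, which is what lets assert_frames_equal compare the two decoders frame by frame. The property relied on, in miniature (bounds are hypothetical, from a 270-high frame and a 100-high crop):

    import torch

    torch.manual_seed(0)
    first = torch.randint(0, 270 - 100 + 1, size=())
    torch.manual_seed(0)
    assert torch.randint(0, 270 - 100 + 1, size=()) == first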