Commit 5017760

Merge branch 'random_crop' into refactor_decoder_transforms

2 parents a92d5f0 + be8ed26

6 files changed: +155 −61 lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp
Lines changed: 3 additions & 1 deletion

```diff
@@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream(
 
   metadataDims_ =
       FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
+  FrameDims currInputDims = metadataDims_;
   for (auto& transform : transforms) {
     TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
     if (transform->getOutputFrameDims().has_value()) {
       resizedOutputDims_ = transform->getOutputFrameDims().value();
     }
-    transform->validate(streamMetadata);
+    transform->validate(currInputDims);
+    currInputDims = resizedOutputDims_.value_or(metadataDims_);
 
     // Note that we are claiming ownership of the transform objects passed in to
     // us.
```
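The effect of this change is that each transform is now validated against the dimensions of the frame it will actually receive, not against the original stream metadata. A minimal Python sketch of the same bookkeeping (illustrative only; `FrameDims` and the duck-typed transform objects stand in for their C++ counterparts):

```python
from typing import List, Optional, Tuple

FrameDims = Tuple[int, int]  # (height, width), standing in for the C++ struct

def validate_transform_chain(
    metadata_dims: FrameDims, transforms: List
) -> Optional[FrameDims]:
    resized_output_dims: Optional[FrameDims] = None  # mirrors resizedOutputDims_
    curr_input_dims = metadata_dims                  # mirrors currInputDims
    for transform in transforms:
        assert transform is not None, "Transforms should never be None!"
        if transform.get_output_frame_dims() is not None:
            resized_output_dims = transform.get_output_frame_dims()
        # Validate against the dims this transform actually sees, which may
        # have been changed by an upstream transform.
        transform.validate(curr_input_dims)
        # Mirrors resizedOutputDims_.value_or(metadataDims_).
        curr_input_dims = resized_output_dims or metadata_dims
    return resized_output_dims
```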

src/torchcodec/_core/Transform.cpp
Lines changed: 37 additions & 7 deletions

```diff
@@ -53,15 +53,45 @@ std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
   return outputDims_;
 }
 
-void CropTransform::validate(const StreamMetadata& streamMetadata) const {
-  TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
+void CropTransform::validate(const FrameDims& inputDims) const {
   TORCH_CHECK(
-      x_ + outputDims_.width <= streamMetadata.width,
-      "Crop x position out of bounds")
-  TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
+      outputDims_.height <= inputDims.height,
+      "Crop output height (",
+      outputDims_.height,
+      ") is greater than input height (",
+      inputDims.height,
+      ")");
   TORCH_CHECK(
-      y_ + outputDims_.height <= streamMetadata.height,
-      "Crop y position out of bounds");
+      outputDims_.width <= inputDims.width,
+      "Crop output width (",
+      outputDims_.width,
+      ") is greater than input width (",
+      inputDims.width,
+      ")");
+  TORCH_CHECK(
+      x_ <= inputDims.width,
+      "Crop x start position, ",
+      x_,
+      ", out of bounds of input width, ",
+      inputDims.width);
+  TORCH_CHECK(
+      x_ + outputDims_.width <= inputDims.width,
+      "Crop x end position, ",
+      x_ + outputDims_.width,
+      ", out of bounds of input width ",
+      inputDims.width);
+  TORCH_CHECK(
+      y_ <= inputDims.height,
+      "Crop y start position, ",
+      y_,
+      ", out of bounds of input height, ",
+      inputDims.height);
+  TORCH_CHECK(
+      y_ + outputDims_.height <= inputDims.height,
+      "Crop y end position, ",
+      y_ + outputDims_.height,
+      ", out of bounds of input height ",
+      inputDims.height);
 }
 
 } // namespace facebook::torchcodec
```
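The new checks validate, in order: that the crop fits inside the input, then that the start and end positions are in bounds on both axes. A compact Python re-expression of the same logic, assuming `x`/`y` are the crop's left/top corner and the checks raise `RuntimeError` on failure (as the failed `TORCH_CHECK`s do in the tests below):

```python
def validate_crop(x: int, y: int, out_h: int, out_w: int, in_h: int, in_w: int) -> None:
    # The crop's output must fit within the input frame.
    if out_h > in_h:
        raise RuntimeError(f"Crop output height ({out_h}) is greater than input height ({in_h})")
    if out_w > in_w:
        raise RuntimeError(f"Crop output width ({out_w}) is greater than input width ({in_w})")
    # The crop's start and end positions must be in bounds on both axes.
    if x > in_w:
        raise RuntimeError(f"Crop x start position, {x}, out of bounds of input width, {in_w}")
    if x + out_w > in_w:
        raise RuntimeError(f"Crop x end position, {x + out_w}, out of bounds of input width {in_w}")
    if y > in_h:
        raise RuntimeError(f"Crop y start position, {y}, out of bounds of input height, {in_h}")
    if y + out_h > in_h:
        raise RuntimeError(f"Crop y end position, {y + out_h}, out of bounds of input height {in_h}")
```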

src/torchcodec/_core/Transform.h
Lines changed: 2 additions & 3 deletions

```diff
@@ -36,8 +36,7 @@ class Transform {
   //
   // Note that the validation function does not return anything. We expect
   // invalid configurations to throw an exception.
-  virtual void validate(
-      [[maybe_unused]] const StreamMetadata& streamMetadata) const {}
+  virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {}
 };
 
 class ResizeTransform : public Transform {
@@ -64,7 +63,7 @@ class CropTransform : public Transform {
 
   std::string getFilterGraphCpu() const override;
   std::optional<FrameDims> getOutputFrameDims() const override;
-  void validate(const StreamMetadata& streamMetadata) const override;
+  void validate(const FrameDims& inputDims) const override;
 
  private:
   FrameDims outputDims_;
```

src/torchcodec/decoders/_video_decoder.py
Lines changed: 11 additions & 12 deletions

```diff
@@ -514,36 +514,35 @@ def _make_transform_specs(
     # dimensions from its input dimensions. We store these with the converted
     # transform, to be all used together when we generate the specs.
     converted_transforms: list[
-        Tuple[DecoderTransform, Tuple[Optional[int], Optional[int]]]
+        Tuple[
+            DecoderTransform,
+            # A (height, width) pair where the values may be missing.
+            Tuple[Optional[int], Optional[int]],
+        ]
     ] = []
     curr_input_dims = input_dims
     for transform in transforms:
-        if isinstance(transform, DecoderTransform):
-            output_dims = transform._calculate_output_dims(curr_input_dims)
-            converted_transforms.append((transform, curr_input_dims))
-        else:
+        if not isinstance(transform, DecoderTransform):
             if not tv_available:
                 raise ValueError(
                     f"The supplied transform, {transform}, is not a TorchCodec "
                     " DecoderTransform. TorchCodec also accepts TorchVision "
                     "v2 transforms, but TorchVision is not installed."
                 )
             elif isinstance(transform, v2.Resize):
-                tc_transform = Resize._from_torchvision(transform)
-                output_dims = tc_transform._calculate_output_dims(curr_input_dims)
-                converted_transforms.append((tc_transform, curr_input_dims))
+                transform = Resize._from_torchvision(transform)
             elif isinstance(transform, v2.RandomCrop):
-                tc_transform = RandomCrop._from_torchvision(transform)
-                output_dims = tc_transform._calculate_output_dims(curr_input_dims)
-                converted_transforms.append((tc_transform, curr_input_dims))
+                transform = RandomCrop._from_torchvision(transform)
             else:
                 raise ValueError(
                     f"Unsupported transform: {transform}. Transforms must be "
                     "either a TorchCodec DecoderTransform or a TorchVision "
                     "v2 transform."
                 )
 
-        curr_input_dims = output_dims
+        converted_transforms.append((transform, curr_input_dims))
+        output_dims = transform._get_output_dims()
+        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
 
     return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
```
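After conversion, TorchCodec and TorchVision v2 transforms can be freely mixed in one pipeline. A usage sketch based on this commit's tests ("video.mp4" is a placeholder path):

```python
import torch
import torchcodec
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

# RandomCrop picks its crop location when the VideoDecoder is constructed,
# so seed beforehand for reproducibility.
torch.manual_seed(0)
decoder = VideoDecoder(
    "video.mp4",  # placeholder path
    transforms=[
        torchcodec.transforms.Resize(size=(2160, 3840)),
        v2.RandomCrop(size=(1080, 1920)),  # converted via RandomCrop._from_torchvision
    ],
)
frame = decoder[0]  # C x 1080 x 1920; every frame is cropped at the same location
```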

src/torchcodec/transforms/_decoder_transforms.py
Lines changed: 41 additions & 28 deletions

```diff
@@ -39,12 +39,44 @@ class DecoderTransform(ABC):
     def _make_transform_spec(
         self, input_dims: Tuple[Optional[int], Optional[int]]
     ) -> str:
+        """Makes the transform spec that is used by the `VideoDecoder`.
+
+        Args:
+            input_dims (Tuple[Optional[int], Optional[int]]): The dimensions of
+                the input frame in the form (height, width). We cannot know the
+                dimensions at object construction time because it's dependent on
+                the video being decoded and upstream transforms in the same
+                transform pipeline. Not all transforms need to know this; those
+                that don't will ignore it. The individual values in the tuple are
+                optional because the original values come from file metadata which
+                may be missing. We maintain the optionality throughout the APIs so
+                that we can decide as late as possible that it's necessary for the
+                values to exist. That is, if the values are missing from the
+                metadata and we have transforms which ignore the input dimensions,
+                we want that to still work.
+
+        Note: This method is the moral equivalent of TorchVision's
+        `Transform.make_params()`.
+
+        Returns:
+            str: A string which contains the spec for the transform that the
+                `VideoDecoder` knows what to do with.
+        """
         pass
 
-    def _calculate_output_dims(
-        self, input_dims: Tuple[Optional[int], Optional[int]]
-    ) -> Tuple[Optional[int], Optional[int]]:
-        return input_dims
+    def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
+        """Get the dimensions of the output frame.
+
+        Transforms that change the frame dimensions need to override this
+        method. Transforms that don't change the frame dimensions can rely on
+        this default implementation.
+
+        Returns:
+            Optional[Tuple[Optional[int], Optional[int]]]: The output dimensions.
+            - None: The output dimensions are the same as the input dimensions.
+            - (int, int): The (height, width) of the output frame.
+        """
+        return None
 
 
 def import_torchvision_transforms_v2() -> ModuleType:
@@ -64,7 +96,7 @@ class Resize(DecoderTransform):
     Interpolation is always bilinear. Anti-aliasing is always on.
 
     Args:
-        size: (sequence of int): Desired output size. Must be a sequence of
+        size (Sequence[int]): Desired output size. Must be a sequence of
             the form (height, width).
     """
 
@@ -81,9 +113,7 @@ def _make_transform_spec(
     ) -> str:
         return f"resize, {self.size[0]}, {self.size[1]}"
 
-    def _calculate_output_dims(
-        self, input_dims: Tuple[Optional[int], Optional[int]]
-    ) -> Tuple[Optional[int], Optional[int]]:
+    def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
         return (self.size[0], self.size[1])
 
     @classmethod
@@ -116,13 +146,13 @@ class RandomCrop(DecoderTransform):
     Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`.
     Padding of all kinds is disabled. The random location within the frame is
     determined during the initialization of the
-    :class:~`torchcodec.decoders.VideoDecoder` object that owns this transform.
+    :class:`~torchcodec.decoders.VideoDecoder` object that owns this transform.
     As a consequence, each decoded frame in the video will be cropped at the
     same location. Videos with variable resolution may result in undefined
     behavior.
 
     Args:
-        size: (sequence of int): Desired output size. Must be a sequence of
+        size (Sequence[int]): Desired output size. Must be a sequence of
             the form (height, width).
     """
 
@@ -159,28 +189,11 @@ def _make_transform_spec(
         )
 
         top = int(torch.randint(0, height - self.size[0] + 1, size=()).item())
-        self._top = top
-
         left = int(torch.randint(0, width - self.size[1] + 1, size=()).item())
-        self._left = left
 
         return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"
 
-    def _calculate_output_dims(
-        self, input_dims: Tuple[Optional[int], Optional[int]]
-    ) -> Tuple[Optional[int], Optional[int]]:
-        height, width = input_dims
-        if height is None:
-            raise ValueError(
-                "Video metadata has no height. "
-                "RandomCrop can only be used when input frame dimensions are known."
-            )
-        if width is None:
-            raise ValueError(
-                "Video metadata has no width. "
-                "RandomCrop can only be used when input frame dimensions are known."
-            )
-
+    def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
         return (self.size[0], self.size[1])
 
     @classmethod
```
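For reference, the spec strings these methods emit look like the following. A sketch; the formats come from the `_make_transform_spec` implementations above, and the exact crop offsets depend on the seeded RNG:

```python
import torch
import torchcodec

resize = torchcodec.transforms.Resize(size=(1080, 1920))
# Resize ignores its input dims, so missing metadata values are fine here.
assert resize._make_transform_spec((None, None)) == "resize, 1080, 1920"
assert resize._get_output_dims() == (1080, 1920)

torch.manual_seed(0)
random_crop = torchcodec.transforms.RandomCrop(size=(99, 99))
# Emits "crop, 99, 99, <left>, <top>" with a freshly sampled location;
# the (height, width) input dims bound where the crop can land.
spec = random_crop._make_transform_spec((888, 888))
assert spec.startswith("crop, 99, 99, ")
```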

test/test_transform_ops.py
Lines changed: 61 additions & 10 deletions

```diff
@@ -172,11 +172,15 @@ def test_random_crop_torchvision(
 
         # We want both kinds of RandomCrop objects to arrive at the same
         # locations to crop, so we need to make sure they get the same random
-        # seed.
+        # seed. It's used in RandomCrop's _make_transform_spec() method, called
+        # by the VideoDecoder.
        torch.manual_seed(seed)
         tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width))
         decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop])
 
+        # Resetting manual seed for when TorchCodec's RandomCrop, created from
+        # the TorchVision RandomCrop, is used inside of the VideoDecoder. It
+        # needs to match the call above.
         torch.manual_seed(seed)
         decoder_random_crop_tv = VideoDecoder(
             video.path,
@@ -202,14 +206,11 @@
         expected_shape = (video.get_num_color_channels(), height, width)
         assert frame_random_crop_tv.shape == expected_shape
 
+        # Resetting manual seed to make sure the invocation of the
+        # TorchVision RandomCrop matches the two calls above.
+        torch.manual_seed(seed)
         frame_full = decoder_full[frame_index]
-        frame_tv = v2.functional.crop(
-            frame_full,
-            top=tc_random_crop._top,
-            left=tc_random_crop._left,
-            height=tc_random_crop.size[0],
-            width=tc_random_crop.size[1],
-        )
+        frame_tv = v2.RandomCrop(size=(height, width))(frame_full)
         assert_frames_equal(frame_random_crop, frame_tv)
 
     @pytest.mark.parametrize(
@@ -266,6 +267,56 @@ def test_crop_fails(self, error_message, params):
                 transforms=[v2.RandomCrop(**params)],
             )
 
+    @pytest.mark.parametrize("seed", [0, 314])
+    def test_random_crop_reusable_objects(self, seed):
+        torch.manual_seed(seed)
+        random_crop = torchcodec.transforms.RandomCrop(size=(99, 99))
+
+        # Create a spec which causes us to calculate the random crop location.
+        first_spec = random_crop._make_transform_spec((888, 888))
+
+        # Create a spec again, which should calculate a different random crop
+        # location. Despite having the same image size, the specs should be
+        # different because the crop should be at a different location.
+        second_spec = random_crop._make_transform_spec((888, 888))
+        assert first_spec != second_spec
+
+        # Create a spec again, but with a different image size. The specs should
+        # obviously be different, but the original image size should not be in
+        # the spec at all.
+        third_spec = random_crop._make_transform_spec((777, 777))
+        assert third_spec != first_spec
+        assert "888" not in third_spec
+
+    @pytest.mark.parametrize(
+        "resize, random_crop",
+        [
+            (torchcodec.transforms.Resize, torchcodec.transforms.RandomCrop),
+            (v2.Resize, v2.RandomCrop),
+        ],
+    )
+    def test_transform_pipeline(self, resize, random_crop):
+        decoder = VideoDecoder(
+            TEST_SRC_2_720P.path,
+            transforms=[
+                # resized to bigger than original
+                resize(size=(2160, 3840)),
+                # crop to smaller than the resize, but still bigger than original
+                random_crop(size=(1080, 1920)),
+            ],
+        )
+
+        num_frames = len(decoder)
+        for frame_index in [
+            0,
+            int(num_frames * 0.25),
+            int(num_frames * 0.5),
+            int(num_frames * 0.75),
+            num_frames - 1,
+        ]:
+            frame = decoder[frame_index]
+            assert frame.shape == (TEST_SRC_2_720P.get_num_color_channels(), 1080, 1920)
+
     def test_transform_fails(self):
         with pytest.raises(
             ValueError,
@@ -528,14 +579,14 @@ def test_crop_transform_fails(self):
 
         with pytest.raises(
             RuntimeError,
-            match="x position out of bounds",
+            match="x start position, 9999, out of bounds",
         ):
             decoder = create_from_file(str(NASA_VIDEO.path))
             add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100")
 
         with pytest.raises(
             RuntimeError,
-            match="y position out of bounds",
+            match=r"Crop output height \(999\) is greater than input height \(270\)",
         ):
             decoder = create_from_file(str(NASA_VIDEO.path))
             add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100")
```
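The new test_transform_pipeline exercises exactly the case the SingleStreamDecoder change enables: a 1080x1920 crop that would be out of bounds for the original 720p stream but is valid against the 2160x3840 resized input. Under the ";"-joined format produced by `_make_transform_specs`, that pipeline serializes to something like the sketch below (the crop offsets are illustrative, and the `create_from_file`/`add_video_stream` helpers are assumed to come from `torchcodec._core` as in this test module):

```python
from torchcodec._core import add_video_stream, create_from_file

decoder = create_from_file("video.mp4")  # placeholder path
# The crop spec appears after the resize spec, so it is validated against the
# resized 2160x3840 dims rather than the original stream metadata dims.
add_video_stream(
    decoder, transform_specs="resize, 2160, 3840;crop, 1080, 1920, 960, 540"
)
```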
