Skip to content

Commit 7bc569a

Browse files
authored
Implement CenterCrop (#1094)
1 parent 1b13e58 commit 7bc569a

File tree

8 files changed

+175
-30
lines changed

8 files changed

+175
-30
lines changed

docs/source/api_ref_transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ For a tutorial, see: TODO_DECODER_TRANSFORMS_TUTORIAL.
1414
:template: dataclass.rst
1515

1616
DecoderTransform
17+
CenterCrop
1718
RandomCrop
1819
Resize

src/torchcodec/_core/Transform.cpp

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,22 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
3737
return outputDims_;
3838
}
3939

40+
CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {}
41+
4042
CropTransform::CropTransform(const FrameDims& dims, int x, int y)
4143
: outputDims_(dims), x_(x), y_(y) {
4244
TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
4345
TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_);
4446
}
4547

4648
std::string CropTransform::getFilterGraphCpu() const {
49+
// For the FFmpeg filter crop, if the x and y coordinates are left
50+
// unspecified, it defaults to a center crop.
51+
std::string coordinates = x_.has_value()
52+
? (":" + std::to_string(x_.value()) + ":" + std::to_string(y_.value()))
53+
: "";
4754
return "crop=" + std::to_string(outputDims_.width) + ":" +
48-
std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" +
49-
std::to_string(y_) + ":exact=1";
55+
std::to_string(outputDims_.height) + coordinates + ":exact=1";
5056
}
5157

5258
std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
@@ -69,29 +75,34 @@ void CropTransform::validate(const FrameDims& inputDims) const {
6975
inputDims.width,
7076
")");
7177
TORCH_CHECK(
72-
x_ <= inputDims.width,
73-
"Crop x start position, ",
74-
x_,
75-
", out of bounds of input width, ",
76-
inputDims.width);
77-
TORCH_CHECK(
78-
x_ + outputDims_.width <= inputDims.width,
79-
"Crop x end position, ",
80-
x_ + outputDims_.width,
81-
", out of bounds of input width ",
82-
inputDims.width);
83-
TORCH_CHECK(
84-
y_ <= inputDims.height,
85-
"Crop y start position, ",
86-
y_,
87-
", out of bounds of input height, ",
88-
inputDims.height);
89-
TORCH_CHECK(
90-
y_ + outputDims_.height <= inputDims.height,
91-
"Crop y end position, ",
92-
y_ + outputDims_.height,
93-
", out of bounds of input height ",
94-
inputDims.height);
78+
x_.has_value() == y_.has_value(),
79+
"Crop x and y values must be both set or both unset");
80+
if (x_.has_value()) {
81+
TORCH_CHECK(
82+
x_.value() <= inputDims.width,
83+
"Crop x start position, ",
84+
x_.value(),
85+
", out of bounds of input width, ",
86+
inputDims.width);
87+
TORCH_CHECK(
88+
x_.value() + outputDims_.width <= inputDims.width,
89+
"Crop x end position, ",
90+
x_.value() + outputDims_.width,
91+
", out of bounds of input width ",
92+
inputDims.width);
93+
TORCH_CHECK(
94+
y_.value() <= inputDims.height,
95+
"Crop y start position, ",
96+
y_.value(),
97+
", out of bounds of input height, ",
98+
inputDims.height);
99+
TORCH_CHECK(
100+
y_.value() + outputDims_.height <= inputDims.height,
101+
"Crop y end position, ",
102+
y_.value() + outputDims_.height,
103+
", out of bounds of input height ",
104+
inputDims.height);
105+
}
95106
}
96107

97108
} // namespace facebook::torchcodec

src/torchcodec/_core/Transform.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,17 @@ class CropTransform : public Transform {
6161
public:
6262
CropTransform(const FrameDims& dims, int x, int y);
6363

64+
// Becomes a center crop if x and y are not specified.
65+
CropTransform(const FrameDims& dims);
66+
6467
std::string getFilterGraphCpu() const override;
6568
std::optional<FrameDims> getOutputFrameDims() const override;
6669
void validate(const FrameDims& inputDims) const override;
6770

6871
private:
6972
FrameDims outputDims_;
70-
int x_;
71-
int y_;
73+
std::optional<int> x_;
74+
std::optional<int> y_;
7275
};
7376

7477
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,23 @@ Transform* makeCropTransform(
288288
return new CropTransform(FrameDims(height, width), x, y);
289289
}
290290

291+
// CenterCrop transform specs take the form:
292+
//
293+
// "center_crop, <height>, <width>"
294+
//
295+
// Where "center_crop" is the string literal and <height>, <width> are
296+
// positive integers. Note that we follow the PyTorch convention of (height,
297+
// width) for specifying image dimensions; FFmpeg uses (width, height).
298+
Transform* makeCenterCropTransform(
299+
const std::vector<std::string>& cropTransformSpec) {
300+
TORCH_CHECK(
301+
cropTransformSpec.size() == 3,
302+
"cropTransformSpec must have 3 elements including its name");
303+
int height = checkedToPositiveInt(cropTransformSpec[1]);
304+
int width = checkedToPositiveInt(cropTransformSpec[2]);
305+
return new CropTransform(FrameDims(height, width));
306+
}
307+
291308
std::vector<std::string> split(const std::string& str, char delimiter) {
292309
std::vector<std::string> tokens;
293310
std::string token;
@@ -317,6 +334,8 @@ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
317334
transforms.push_back(makeResizeTransform(transformSpec));
318335
} else if (name == "crop") {
319336
transforms.push_back(makeCropTransform(transformSpec));
337+
} else if (name == "center_crop") {
338+
transforms.push_back(makeCenterCropTransform(transformSpec));
320339
} else {
321340
TORCH_CHECK(false, "Invalid transform name: " + name);
322341
}

src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
create_decoder,
2020
ERROR_REPORTING_INSTRUCTIONS,
2121
)
22-
from torchcodec.transforms import DecoderTransform, RandomCrop, Resize
22+
from torchcodec.transforms import CenterCrop, DecoderTransform, RandomCrop, Resize
2323

2424

2525
class VideoDecoder:
@@ -531,6 +531,8 @@ def _make_transform_specs(
531531
)
532532
elif isinstance(transform, v2.Resize):
533533
transform = Resize._from_torchvision(transform)
534+
elif isinstance(transform, v2.CenterCrop):
535+
transform = CenterCrop._from_torchvision(transform)
534536
elif isinstance(transform, v2.RandomCrop):
535537
transform = RandomCrop._from_torchvision(transform)
536538
else:

src/torchcodec/transforms/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from ._decoder_transforms import DecoderTransform, RandomCrop, Resize # noqa
7+
from ._decoder_transforms import ( # noqa
8+
CenterCrop,
9+
DecoderTransform,
10+
RandomCrop,
11+
Resize,
12+
)

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,55 @@ def _from_torchvision(cls, tv_resize: nn.Module):
140140
return cls(size=tv_resize.size)
141141

142142

143+
class CenterCrop(DecoderTransform):
144+
"""Crop the decoded frame to a given size in the center of the frame.
145+
146+
Complementary TorchVision transform: :class:`~torchvision.transforms.v2.CenterCrop`.
147+
148+
Args:
149+
size (Sequence[int]): Desired output size. Must be a sequence of
150+
the form (height, width).
151+
"""
152+
153+
def __init__(self, size: Sequence[int]):
154+
if len(size) != 2:
155+
raise ValueError(
156+
"CenterCrop transform must have a (height, width) "
157+
f"pair for the size, got {size}."
158+
)
159+
self.size = size
160+
161+
def _make_transform_spec(
162+
self, input_dims: Tuple[Optional[int], Optional[int]]
163+
) -> str:
164+
return f"center_crop, {self.size[0]}, {self.size[1]}"
165+
166+
def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
167+
return (self.size[0], self.size[1])
168+
169+
@classmethod
170+
def _from_torchvision(
171+
cls,
172+
tv_center_crop: nn.Module,
173+
):
174+
v2 = import_torchvision_transforms_v2()
175+
176+
if not isinstance(tv_center_crop, v2.CenterCrop):
177+
raise ValueError(
178+
"Transform must be TorchVision's CenterCrop, "
179+
f"it is instead {type(tv_center_crop).__name__}. "
180+
"This should never happen, please report a bug."
181+
)
182+
183+
if len(tv_center_crop.size) != 2:
184+
raise ValueError(
185+
"TorchVision CenterCrop transform must have a (height, width) "
186+
f"pair for the size, got {tv_center_crop.size}."
187+
)
188+
189+
return cls(size=tv_center_crop.size)
190+
191+
143192
class RandomCrop(DecoderTransform):
144193
"""Crop the decoded frame to a given size at a random location in the frame.
145194

test/test_transform_ops.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,61 @@ def test_resize_fails(self):
154154
transforms=[torchcodec.transforms.Resize(size=(100, 100, 100))],
155155
)
156156

157+
@pytest.mark.parametrize(
158+
"height_scaling_factor, width_scaling_factor",
159+
((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)),
160+
)
161+
@pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
162+
def test_center_crop_torchvision(
163+
self,
164+
height_scaling_factor,
165+
width_scaling_factor,
166+
video,
167+
):
168+
height = int(video.get_height() * height_scaling_factor)
169+
width = int(video.get_width() * width_scaling_factor)
170+
171+
tc_center_crop = torchcodec.transforms.CenterCrop(size=(height, width))
172+
decoder_center_crop = VideoDecoder(video.path, transforms=[tc_center_crop])
173+
174+
decoder_center_crop_tv = VideoDecoder(
175+
video.path,
176+
transforms=[v2.CenterCrop(size=(height, width))],
177+
)
178+
179+
decoder_full = VideoDecoder(video.path)
180+
181+
num_frames = len(decoder_center_crop_tv)
182+
assert num_frames == len(decoder_full)
183+
184+
for frame_index in [
185+
0,
186+
int(num_frames * 0.25),
187+
int(num_frames * 0.5),
188+
int(num_frames * 0.75),
189+
num_frames - 1,
190+
]:
191+
frame_center_crop = decoder_center_crop[frame_index]
192+
frame_center_crop_tv = decoder_center_crop_tv[frame_index]
193+
assert_frames_equal(frame_center_crop, frame_center_crop_tv)
194+
195+
expected_shape = (video.get_num_color_channels(), height, width)
196+
assert frame_center_crop_tv.shape == expected_shape
197+
198+
frame_full = decoder_full[frame_index]
199+
frame_tv = v2.CenterCrop(size=(height, width))(frame_full)
200+
assert_frames_equal(frame_center_crop, frame_tv)
201+
202+
def test_center_crop_fails(self):
203+
with pytest.raises(
204+
ValueError,
205+
match=r"must have a \(height, width\) pair for the size",
206+
):
207+
VideoDecoder(
208+
NASA_VIDEO.path,
209+
transforms=[torchcodec.transforms.CenterCrop(size=(100,))],
210+
)
211+
157212
@pytest.mark.parametrize(
158213
"height_scaling_factor, width_scaling_factor",
159214
((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)),
@@ -257,7 +312,7 @@ def test_random_crop_nhwc(
257312
),
258313
),
259314
)
260-
def test_crop_fails(self, error_message, params):
315+
def test_random_crop_fails(self, error_message, params):
261316
with pytest.raises(
262317
ValueError,
263318
match=error_message,

0 commit comments

Comments (0)