
Commit 7e43313

Docstrings, better error checking, better testing
1 parent fd8f7a5 commit 7e43313

3 files changed: +125 −12 lines changed

src/torchcodec/_core/custom_ops.cpp

Lines changed: 15 additions & 2 deletions
@@ -239,6 +239,19 @@ int checkedToPositiveInt(const std::string& str) {
   return ret;
 }
 
+int checkedToNonNegativeInt(const std::string& str) {
+  int ret = 0;
+  try {
+    ret = std::stoi(str);
+  } catch (const std::invalid_argument&) {
+    TORCH_CHECK(false, "String cannot be converted to an int:" + str);
+  } catch (const std::out_of_range&) {
+    TORCH_CHECK(false, "String would become integer out of range:" + str);
+  }
+  TORCH_CHECK(ret >= 0, "String must be a non-negative integer:" + str);
+  return ret;
+}
+
 // Resize transform specs take the form:
 //
 //   "resize, <height>, <width>"
@@ -270,8 +283,8 @@ Transform* makeCropTransform(
       "cropTransformSpec must have 5 elements including its name");
   int height = checkedToPositiveInt(cropTransformSpec[1]);
   int width = checkedToPositiveInt(cropTransformSpec[2]);
-  int x = checkedToPositiveInt(cropTransformSpec[3]);
-  int y = checkedToPositiveInt(cropTransformSpec[4]);
+  int x = checkedToNonNegativeInt(cropTransformSpec[3]);
+  int y = checkedToNonNegativeInt(cropTransformSpec[4]);
   return new CropTransform(FrameDims(height, width), x, y);
 }
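The new checkedToNonNegativeInt differs from checkedToPositiveInt only in accepting zero, which is what crop offsets need: an x or y of 0 (a crop flush with the frame's top or left edge) is valid, while a height or width of 0 is not. Below is a minimal Python sketch of that validation split, assuming the crop spec follows the comma-separated layout implied by the resize comment above (roughly "crop, <height>, <width>, <x>, <y>"); the function names and the parsing helper here are illustrative only, not the library's API.

def checked_to_positive_int(s: str) -> int:
    # int() rejects non-numeric input, roughly like std::stoi throwing invalid_argument
    value = int(s)
    if value <= 0:
        raise ValueError(f"String must be a positive integer: {s}")
    return value

def checked_to_non_negative_int(s: str) -> int:
    value = int(s)
    if value < 0:
        raise ValueError(f"String must be a non-negative integer: {s}")
    return value

# A crop anchored at the frame's top-left corner now validates cleanly:
name, height, width, x, y = "crop,270,480,0,0".split(",")
checked_to_positive_int(height), checked_to_positive_int(width)   # 270, 480
checked_to_non_negative_int(x), checked_to_non_negative_int(y)    # 0 is accepted here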

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 39 additions & 9 deletions
@@ -23,8 +23,8 @@ class DecoderTransform(ABC):
     decoded frames and applying the same kind of transform.
 
     Most ``DecoderTransform`` objects have a complementary transform in TorchVision,
-    specificially in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_. For such transforms, we
-    ensure that:
+    specificially in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_.
+    For such transforms, we ensure that:
 
     1. The names are the same.
     2. Default behaviors are the same.
@@ -74,7 +74,7 @@ def _make_transform_spec(self) -> str:
         return f"resize, {self.size[0]}, {self.size[1]}"
 
     def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]:
-        return self.size
+        return (*self.size,)
 
     @classmethod
     def _from_torchvision(cls, resize_tv: nn.Module):
@@ -102,20 +102,51 @@ def _from_torchvision(cls, resize_tv: nn.Module):
 
 @dataclass
 class RandomCrop(DecoderTransform):
+    """Crop the decoded frame to a given size at a random location in the frame.
+
+    Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`.
+    Padding of all kinds is disabled. The random location within the frame is
+    determined during the initialization of the
+    :class:~`torchcodec.decoders.VideoDecoder` object that owns this transform.
+    As a consequence, each decoded frame in the video will be cropped at the
+    same location. Videos with variable resolution may result in undefined
+    behavior.
+
+    Args:
+        size: (sequence of int): Desired output size. Must be a sequence of
+            the form (height, width).
+    """
 
     size: Sequence[int]
     _top: Optional[int] = None
     _left: Optional[int] = None
    _input_dims: Optional[Tuple[int, int]] = None
 
     def _make_transform_spec(self) -> str:
-        assert len(self.size) == 2
+        if len(self.size) != 2:
+            raise ValueError(
+                f"RandomCrop's size must be a sequence of length 2, got {self.size}. "
+                "This should never happen, please report a bug."
+            )
+
         if self._top is None or self._left is None:
-            assert self._input_dims is not None
+            # TODO: It would be very strange if only ONE of those is None. But should we
+            # make it an error? We can continue, but it would probably mean
+            # something bad happened. Dear reviewer, please register an opinion here:
+            if self._input_dims is None:
+                raise ValueError(
+                    "RandomCrop's input_dims must be set before calling _make_transform_spec(). "
+                    "This should never happen, please report a bug."
+                )
             if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]:
                 raise ValueError(
                     f"Input dimensions {input_dims} are smaller than the crop size {self.size}."
                 )
+
+            # Note: This logic must match the logic in
+            # torchvision.transforms.v2.RandomCrop.make_params(). Given
+            # the same seed, they should get the same result. This is an
+            # API guarantee with our users.
             self._top = torch.randint(
                 0, self._input_dims[0] - self.size[0] + 1, size=()
             )
@@ -144,17 +175,16 @@ def _from_torchvision(cls, random_crop_tv: nn.Module, input_dims: Tuple[int, int
                 "TorchVision RandomCrop transform must not specify pad_if_needed."
             )
         if random_crop_tv.fill != 0:
-            raise ValueError("TorchVision RandomCrop must specify fill of 0.")
+            raise ValueError("TorchVision RandomCrop fill must be 0.")
         if random_crop_tv.padding_mode != "constant":
-            raise ValueError(
-                "TorchVision RandomCrop must specify padding_mode of constant."
-            )
+            raise ValueError("TorchVision RandomCrop padding_mode must be constant.")
         if len(random_crop_tv.size) != 2:
             raise ValueError(
                 "TorchVision RandcomCrop transform must have a (height, width) "
                 f"pair for the size, got {random_crop_tv.size}."
             )
         params = random_crop_tv.make_params(
+            # TODO: deal with NCHW versus NHWC; video decoder knows
             torch.empty(size=(3, *input_dims), dtype=torch.uint8)
         )
         assert random_crop_tv.size == (params["height"], params["width"])
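The new RandomCrop docstring above pins down the decoder-side behavior: the crop location is sampled once, when the owning VideoDecoder is constructed, and then reused for every decoded frame. A short usage sketch of that contract, based on the test below; the file path, sizes, and frame indices are placeholders.

import torch
import torchcodec.transforms
from torchcodec.decoders import VideoDecoder

torch.manual_seed(0)  # seeding makes the sampled crop location reproducible
crop = torchcodec.transforms.RandomCrop(size=(270, 480))  # (height, width)
decoder = VideoDecoder("video.mp4", transforms=[crop])    # (top, left) is drawn here

# Every frame is cropped at the same location, so all outputs share one shape.
first, later = decoder[0], decoder[50]
assert first.shape[-2:] == (270, 480)
assert later.shape[-2:] == (270, 480)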

test/test_transform_ops.py

Lines changed: 71 additions & 1 deletion
@@ -147,7 +147,7 @@ def test_resize_fails(self):
 
     @pytest.mark.parametrize(
         "height_scaling_factor, width_scaling_factor",
-        ((0.5, 0.5), (0.25, 0.1)),
+        ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.25, 0.25)),
     )
     @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
     def test_random_crop_torchvision(
@@ -156,6 +156,9 @@ def test_random_crop_torchvision(
         height = int(video.get_height() * height_scaling_factor)
         width = int(video.get_width() * width_scaling_factor)
 
+        # We want both kinds of RandomCrop objects to get arrive at the same
+        # locations to crop, so we need to make sure they get the same random
+        # seed.
         torch.manual_seed(0)
         tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width))
         decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop])
@@ -188,6 +191,73 @@ def test_random_crop_torchvision(
         expected_shape = (video.get_num_color_channels(), height, width)
         assert frame_random_crop_tv.shape == expected_shape
 
+        frame_full = decoder_full[frame_index]
+        frame_tv = v2.functional.crop(
+            frame_full,
+            top=tc_random_crop._top,
+            left=tc_random_crop._left,
+            height=tc_random_crop.size[0],
+            width=tc_random_crop.size[1],
+        )
+        assert_frames_equal(frame_random_crop, frame_tv)
+
+    def test_crop_fails(self):
+        with pytest.raises(
+            ValueError,
+            match="must not specify padding",
+        ):
+            VideoDecoder(
+                NASA_VIDEO.path,
+                transforms=[
+                    v2.RandomCrop(
+                        size=(100, 100),
+                        padding=255,
+                    )
+                ],
+            )
+
+        with pytest.raises(
+            ValueError,
+            match="must not specify pad_if_needed",
+        ):
+            VideoDecoder(
+                NASA_VIDEO.path,
+                transforms=[
+                    v2.RandomCrop(
+                        size=(100, 100),
+                        pad_if_needed=True,
+                    )
+                ],
+            )
+
+        with pytest.raises(
+            ValueError,
+            match="fill must be 0",
+        ):
+            VideoDecoder(
+                NASA_VIDEO.path,
+                transforms=[
+                    v2.RandomCrop(
+                        size=(100, 100),
+                        fill=255,
+                    )
+                ],
+            )
+
+        with pytest.raises(
+            ValueError,
+            match="padding_mode must be constant",
+        ):
+            VideoDecoder(
+                NASA_VIDEO.path,
+                transforms=[
+                    v2.RandomCrop(
+                        size=(100, 100),
+                        padding_mode="edge",
+                    )
+                ],
+            )
+
     def test_transform_fails(self):
         with pytest.raises(
             ValueError,
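The heart of test_random_crop_torchvision is the seed-matching guarantee called out in the comment in _decoder_transforms.py: given the same torch seed, torchcodec's RandomCrop and torchvision's v2.RandomCrop.make_params() choose the same (top, left), so cropping the full decoded frame with v2.functional.crop reproduces the decoder's output. A condensed sketch of that check, with a placeholder path and frame index, and torch.testing.assert_close standing in for the suite's assert_frames_equal helper.

import torch
import torchcodec.transforms
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

torch.manual_seed(0)
crop = torchcodec.transforms.RandomCrop(size=(270, 480))
cropped = VideoDecoder("video.mp4", transforms=[crop])[0]   # decoder-side crop

full = VideoDecoder("video.mp4")[0]  # same frame, decoded without transforms
reference = v2.functional.crop(
    full,
    top=crop._top,
    left=crop._left,
    height=crop.size[0],
    width=crop.size[1],
)
torch.testing.assert_close(cropped, reference)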
