Skip to content

Commit a15d458

Browse files
committed
More comments, add pytest to reference resources
1 parent 88bc94a commit a15d458

File tree

4 files changed

+26
-11
lines changed

4 files changed

+26
-11
lines changed

.github/workflows/reference_resources.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
# Note that we're installing stable - this is for running a script where we're a normal PyTorch
4444
# user, not for building TorhCodec.
4545
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
46-
python -m pip install numpy pillow
46+
python -m pip install numpy pillow pytest
4747
4848
- name: Check out repo
4949
uses: actions/checkout@v3

src/torchcodec/_core/Transform.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
7575

7676
void CropTransform::validate(const StreamMetadata& streamMetadata) const {
7777
TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
78+
TORCH_CHECK(x_ + outputDims_.width <= streamMetadata.width, "Crop x position out of bounds")
7879
TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
80+
TORCH_CHECK(y_ + outputDims_.height <= streamMetadata.height, "Crop y position out of bounds");
7981
}
8082

8183
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,12 @@ Transform* makeResizeTransform(
214214

215215
// Crop transform specs take the form:
216216
//
217-
// "crop, <height>, <width> <x>, <y>"
217+
// "crop, <height>, <width>, <x>, <y>"
218218
//
219219
// Where "crop" is the string literal and <height>, <width>, <x> and <y> are
220-
// positive integers. Note that that in this spec, we are following the
221-
// filtergraph convention of (width, height). This makes it easier to compare it
222-
// against actual filtergraph strings.
220+
// positive integers. <x> and <y> are the x and y coordinates of the top left
221+
// corner of the crop. Note that we follow the PyTorch convention of (height,
222+
// width) for specifying image dimensions; FFmpeg uses (width, height).
223223
Transform* makeCropTransform(
224224
const std::vector<std::string>& cropTransformSpec) {
225225
TORCH_CHECK(

test/test_transform_ops.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import contextlib
88
import os
99

10-
os.environ["TORCH_LOGS"] = "output_code"
1110
import json
1211
import subprocess
1312

@@ -210,9 +209,13 @@ def test_resize_transform_fails(self):
210209

211210
def test_crop_transform(self):
212211
# Note that filtergraph accepts dimensions as (w, h) and we accept them as (h, w).
213-
crop_spec = "crop, 200, 300, 50, 35" # h=200, w=300, x=50, y=35
214-
crop_filtergraph = "crop=300:200:50:35:exact=1" # w=300, h=200, x=50, y=35
215-
expected_shape = (3, 200, 300) # channels=3, height=200, width=300
212+
width = 300
213+
height = 200
214+
x = 50
215+
y = 35
216+
crop_spec = f"crop, {height}, {width}, {x}, {y}"
217+
crop_filtergraph = f"crop={width}:{height}:{x}:{y}:exact=1"
218+
expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
216219

217220
decoder_crop = create_from_file(str(NASA_VIDEO.path))
218221
add_video_stream(decoder_crop, transform_specs=crop_spec)
@@ -228,7 +231,7 @@ def test_crop_transform(self):
228231

229232
frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index)
230233
frame_tv = v2.functional.crop(
231-
frame_full, top=35, left=50, height=200, width=300
234+
frame_full, top=y, left=x, height=height, width=width
232235
)
233236

234237
assert frame.shape == expected_shape
@@ -239,28 +242,38 @@ def test_crop_transform(self):
239242
assert_frames_equal(frame, frame_ref)
240243

241244
def test_crop_transform_fails(self):
242-
decoder = create_from_file(str(NASA_VIDEO.path))
243245

244246
with pytest.raises(
245247
RuntimeError,
246248
match="must have 5 elements",
247249
):
250+
decoder = create_from_file(str(NASA_VIDEO.path))
248251
add_video_stream(decoder, transform_specs="crop, 100, 100")
249252

250253
with pytest.raises(
251254
RuntimeError,
252255
match="must be a positive integer",
253256
):
257+
decoder = create_from_file(str(NASA_VIDEO.path))
254258
add_video_stream(decoder, transform_specs="crop, -10, 100, 100, 100")
255259

256260
with pytest.raises(
257261
RuntimeError,
258262
match="cannot be converted to an int",
259263
):
264+
decoder = create_from_file(str(NASA_VIDEO.path))
260265
add_video_stream(decoder, transform_specs="crop, 100, 100, blah, 100")
261266

262267
with pytest.raises(
263268
RuntimeError,
264269
match="x position out of bounds",
265270
):
271+
decoder = create_from_file(str(NASA_VIDEO.path))
266272
add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100")
273+
274+
with pytest.raises(
275+
RuntimeError,
276+
match="y position out of bounds",
277+
):
278+
decoder = create_from_file(str(NASA_VIDEO.path))
279+
add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100")

0 commit comments

Comments
 (0)