Skip to content

Commit f8844f4

Browse files
committed
Simplify; handle pipelines
1 parent 817b1f8 commit f8844f4

File tree

6 files changed

+100
-40
lines changed

6 files changed

+100
-40
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream(
545545

546546
metadataDims_ =
547547
FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
548+
FrameDims currInputDims = metadataDims_;
548549
for (auto& transform : transforms) {
549550
TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
550551
if (transform->getOutputFrameDims().has_value()) {
551552
resizedOutputDims_ = transform->getOutputFrameDims().value();
552553
}
553-
transform->validate(streamMetadata);
554+
transform->validate(currInputDims);
555+
currInputDims = resizedOutputDims_.value_or(metadataDims_);
554556

555557
// Note that we are claiming ownership of the transform objects passed in to
556558
// us.

src/torchcodec/_core/Transform.cpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,45 @@ std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
5353
return outputDims_;
5454
}
5555

56-
void CropTransform::validate(const StreamMetadata& streamMetadata) const {
57-
TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
56+
void CropTransform::validate(const FrameDims& inputDims) const {
5857
TORCH_CHECK(
59-
x_ + outputDims_.width <= streamMetadata.width,
60-
"Crop x position out of bounds")
61-
TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
58+
outputDims_.height <= inputDims.height,
59+
"Crop output height (",
60+
outputDims_.height,
61+
") is greater than input height (",
62+
inputDims.height,
63+
")");
6264
TORCH_CHECK(
63-
y_ + outputDims_.height <= streamMetadata.height,
64-
"Crop y position out of bounds");
65+
outputDims_.width <= inputDims.width,
66+
"Crop output width (",
67+
outputDims_.width,
68+
") is greater than input width (",
69+
inputDims.width,
70+
")");
71+
TORCH_CHECK(
72+
x_ <= inputDims.width,
73+
"Crop x start position, ",
74+
x_,
75+
", out of bounds of input width, ",
76+
inputDims.width);
77+
TORCH_CHECK(
78+
x_ + outputDims_.width <= inputDims.width,
79+
"Crop x end position, ",
80+
x_ + outputDims_.width,
81+
", out of bounds of input width ",
82+
inputDims.width);
83+
TORCH_CHECK(
84+
y_ <= inputDims.height,
85+
"Crop y start position, ",
86+
y_,
87+
", out of bounds of input height, ",
88+
inputDims.height);
89+
TORCH_CHECK(
90+
y_ + outputDims_.height <= inputDims.height,
91+
"Crop y end position, ",
92+
y_ + outputDims_.height,
93+
", out of bounds of input height ",
94+
inputDims.height);
6595
}
6696

6797
} // namespace facebook::torchcodec

src/torchcodec/_core/Transform.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ class Transform {
3636
//
3737
// Note that the validation function does not return anything. We expect
3838
// invalid configurations to throw an exception.
39-
virtual void validate(
40-
[[maybe_unused]] const StreamMetadata& streamMetadata) const {}
39+
virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {}
4140
};
4241

4342
class ResizeTransform : public Transform {
@@ -64,7 +63,7 @@ class CropTransform : public Transform {
6463

6564
std::string getFilterGraphCpu() const override;
6665
std::optional<FrameDims> getOutputFrameDims() const override;
67-
void validate(const StreamMetadata& streamMetadata) const override;
66+
void validate(const FrameDims& inputDims) const override;
6867

6968
private:
7069
FrameDims outputDims_;

src/torchcodec/decoders/_video_decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _make_transform_specs(
519519
curr_input_dims = input_dims
520520
for transform in transforms:
521521
if isinstance(transform, DecoderTransform):
522-
output_dims = transform._calculate_output_dims(curr_input_dims)
522+
output_dims = transform._get_output_dims()
523523
converted_transforms.append((transform, curr_input_dims))
524524
else:
525525
if not tv_available:
@@ -530,11 +530,11 @@ def _make_transform_specs(
530530
)
531531
elif isinstance(transform, v2.Resize):
532532
tc_transform = Resize._from_torchvision(transform)
533-
output_dims = tc_transform._calculate_output_dims(curr_input_dims)
533+
output_dims = tc_transform._get_output_dims()
534534
converted_transforms.append((tc_transform, curr_input_dims))
535535
elif isinstance(transform, v2.RandomCrop):
536536
tc_transform = RandomCrop._from_torchvision(transform)
537-
output_dims = tc_transform._calculate_output_dims(curr_input_dims)
537+
output_dims = tc_transform._get_output_dims()
538538
converted_transforms.append((tc_transform, curr_input_dims))
539539
else:
540540
raise ValueError(
@@ -543,7 +543,7 @@ def _make_transform_specs(
543543
"v2 transform."
544544
)
545545

546-
curr_input_dims = output_dims
546+
curr_input_dims = output_dims if output_dims is not None else curr_input_dims
547547

548548
return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
549549

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ def _make_transform_spec(
4343
) -> str:
4444
pass
4545

46-
def _calculate_output_dims(
47-
self, input_dims: Tuple[Optional[int], Optional[int]]
48-
) -> Tuple[Optional[int], Optional[int]]:
49-
return input_dims
46+
# Transforms that change the dimensions of their input frame return a value.
47+
# Transforms that don't change dimensions return None; they can rely on this
48+
# default implementation.
49+
def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
50+
return None
5051

5152

5253
def import_torchvision_transforms_v2() -> ModuleType:
@@ -80,9 +81,7 @@ def _make_transform_spec(
8081
assert len(self.size) == 2
8182
return f"resize, {self.size[0]}, {self.size[1]}"
8283

83-
def _calculate_output_dims(
84-
self, input_dims: Tuple[Optional[int], Optional[int]]
85-
) -> Tuple[Optional[int], Optional[int]]:
84+
def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
8685
# TODO: establish this invariant in the constructor during refactor
8786
assert len(self.size) == 2
8887
return (self.size[0], self.size[1])
@@ -173,24 +172,9 @@ def _make_transform_spec(
173172

174173
return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"
175174

176-
def _calculate_output_dims(
177-
self, input_dims: Tuple[Optional[int], Optional[int]]
178-
) -> Tuple[Optional[int], Optional[int]]:
175+
def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
179176
# TODO: establish this invariant in the constructor during refactor
180177
assert len(self.size) == 2
181-
182-
height, width = input_dims
183-
if height is None:
184-
raise ValueError(
185-
"Video metadata has no height. "
186-
"RandomCrop can only be used when input frame dimensions are known."
187-
)
188-
if width is None:
189-
raise ValueError(
190-
"Video metadata has no width. "
191-
"RandomCrop can only be used when input frame dimensions are known."
192-
)
193-
194178
return (self.size[0], self.size[1])
195179

196180
@classmethod

test/test_transform_ops.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,51 @@ def test_crop_fails(self, error_message, params):
257257
transforms=[v2.RandomCrop(**params)],
258258
)
259259

260+
@pytest.mark.parametrize("seed", [0, 314])
261+
def test_random_crop_reusable_objects(self, seed):
262+
torch.manual_seed(seed)
263+
random_crop = torchcodec.transforms.RandomCrop(size=(100, 100))
264+
265+
# Create a spec which causes us to calculate the random crop location.
266+
_ = random_crop._make_transform_spec((1000, 1000))
267+
first_top = random_crop._top
268+
first_left = random_crop._left
269+
270+
# Create a spec again, which should calculate a different random crop
271+
# location.
272+
_ = random_crop._make_transform_spec((1000, 1000))
273+
assert first_top != random_crop._top
274+
assert first_left != random_crop._left
275+
276+
@pytest.mark.parametrize(
277+
"resize, random_crop",
278+
[
279+
(torchcodec.transforms.Resize, torchcodec.transforms.RandomCrop),
280+
(v2.Resize, v2.RandomCrop),
281+
],
282+
)
283+
def test_transform_pipeline(self, resize, random_crop):
284+
decoder = VideoDecoder(
285+
TEST_SRC_2_720P.path,
286+
transforms=[
287+
# resized to bigger than original
288+
resize(size=(2160, 3840)),
289+
# crop to smaller than the resize, but still bigger than original
290+
random_crop(size=(1080, 1920)),
291+
],
292+
)
293+
294+
num_frames = len(decoder)
295+
for frame_index in [
296+
0,
297+
int(num_frames * 0.25),
298+
int(num_frames * 0.5),
299+
int(num_frames * 0.75),
300+
num_frames - 1,
301+
]:
302+
frame = decoder[frame_index]
303+
assert frame.shape == (TEST_SRC_2_720P.get_num_color_channels(), 1080, 1920)
304+
260305
def test_transform_fails(self):
261306
with pytest.raises(
262307
ValueError,
@@ -519,14 +564,14 @@ def test_crop_transform_fails(self):
519564

520565
with pytest.raises(
521566
RuntimeError,
522-
match="x position out of bounds",
567+
match="x start position, 9999, out of bounds",
523568
):
524569
decoder = create_from_file(str(NASA_VIDEO.path))
525570
add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100")
526571

527572
with pytest.raises(
528573
RuntimeError,
529-
match="y position out of bounds",
574+
match=r"Crop output height \(999\) is greater than input height \(270\)",
530575
):
531576
decoder = create_from_file(str(NASA_VIDEO.path))
532577
add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100")

0 commit comments

Comments (0)