Commit 71fac63

author: pytorchbot
committed: 2025-12-04 nightly release (38fa96c)
1 parent 7599afd

File tree: 8 files changed, +451 additions, -71 deletions

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 3 additions & 1 deletion
@@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream(
 
   metadataDims_ =
       FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
+  FrameDims currInputDims = metadataDims_;
   for (auto& transform : transforms) {
     TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
     if (transform->getOutputFrameDims().has_value()) {
       resizedOutputDims_ = transform->getOutputFrameDims().value();
     }
-    transform->validate(streamMetadata);
+    transform->validate(currInputDims);
+    currInputDims = resizedOutputDims_.value_or(metadataDims_);
 
   // Note that we are claiming ownership of the transform objects passed in to
   // us.
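The net effect of this hunk: each transform is now validated against the dimensions of the frame it will actually receive, not the raw stream metadata. A minimal Python sketch of the same bookkeeping (hypothetical names, for illustration only; the real logic is the C++ loop above):

from typing import Optional, Tuple

Dims = Tuple[int, int]  # (height, width)

def validate_pipeline(metadata_dims: Dims, transforms) -> None:
    # Start from the metadata dims; after each transform, feed the latest
    # known output dims forward as the next transform's input dims.
    resized_output_dims: Optional[Dims] = None
    curr_input_dims = metadata_dims
    for transform in transforms:
        if transform.get_output_frame_dims() is not None:
            resized_output_dims = transform.get_output_frame_dims()
        transform.validate(curr_input_dims)
        curr_input_dims = (
            resized_output_dims if resized_output_dims is not None else metadata_dims
        )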

src/torchcodec/_core/Transform.cpp

Lines changed: 37 additions & 7 deletions
@@ -53,15 +53,45 @@ std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
   return outputDims_;
 }
 
-void CropTransform::validate(const StreamMetadata& streamMetadata) const {
-  TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
+void CropTransform::validate(const FrameDims& inputDims) const {
   TORCH_CHECK(
-      x_ + outputDims_.width <= streamMetadata.width,
-      "Crop x position out of bounds")
-  TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
+      outputDims_.height <= inputDims.height,
+      "Crop output height (",
+      outputDims_.height,
+      ") is greater than input height (",
+      inputDims.height,
+      ")");
   TORCH_CHECK(
-      y_ + outputDims_.height <= streamMetadata.height,
-      "Crop y position out of bounds");
+      outputDims_.width <= inputDims.width,
+      "Crop output width (",
+      outputDims_.width,
+      ") is greater than input width (",
+      inputDims.width,
+      ")");
+  TORCH_CHECK(
+      x_ <= inputDims.width,
+      "Crop x start position, ",
+      x_,
+      ", out of bounds of input width, ",
+      inputDims.width);
+  TORCH_CHECK(
+      x_ + outputDims_.width <= inputDims.width,
+      "Crop x end position, ",
+      x_ + outputDims_.width,
+      ", out of bounds of input width ",
+      inputDims.width);
+  TORCH_CHECK(
+      y_ <= inputDims.height,
+      "Crop y start position, ",
+      y_,
+      ", out of bounds of input height, ",
+      inputDims.height);
+  TORCH_CHECK(
+      y_ + outputDims_.height <= inputDims.height,
+      "Crop y end position, ",
+      y_ + outputDims_.height,
+      ", out of bounds of input height ",
+      inputDims.height);
 }
 
 } // namespace facebook::torchcodec
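The rewritten validate() enforces six conditions against the incoming frame (crop size fits, and both the start and end positions on each axis are in bounds), and, per the comment in Transform.h below, it reports failures by throwing rather than returning a status. A hedged Python restatement with a worked example:

def validate_crop(input_h, input_w, out_h, out_w, x, y):
    # Illustrative restatement of CropTransform::validate, not TorchCodec API.
    assert out_h <= input_h, f"Crop output height ({out_h}) > input height ({input_h})"
    assert out_w <= input_w, f"Crop output width ({out_w}) > input width ({input_w})"
    assert x <= input_w, f"Crop x start position {x} out of bounds"
    assert x + out_w <= input_w, f"Crop x end position {x + out_w} out of bounds"
    assert y <= input_h, f"Crop y start position {y} out of bounds"
    assert y + out_h <= input_h, f"Crop y end position {y + out_h} out of bounds"

# A 100x200 crop at (x=10, y=20) fits inside a 720x1280 frame; the same
# crop at x=1200 fails because 1200 + 200 > 1280.
validate_crop(720, 1280, 100, 200, 10, 20)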

src/torchcodec/_core/Transform.h

Lines changed: 2 additions & 3 deletions
@@ -36,8 +36,7 @@ class Transform {
   //
   // Note that the validation function does not return anything. We expect
   // invalid configurations to throw an exception.
-  virtual void validate(
-      [[maybe_unused]] const StreamMetadata& streamMetadata) const {}
+  virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {}
 };
 
 class ResizeTransform : public Transform {
@@ -64,7 +63,7 @@ class CropTransform : public Transform {
 
   std::string getFilterGraphCpu() const override;
   std::optional<FrameDims> getOutputFrameDims() const override;
-  void validate(const StreamMetadata& streamMetadata) const override;
+  void validate(const FrameDims& inputDims) const override;
 
  private:
   FrameDims outputDims_;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 15 additions & 2 deletions
@@ -239,6 +239,19 @@ int checkedToPositiveInt(const std::string& str) {
   return ret;
 }
 
+int checkedToNonNegativeInt(const std::string& str) {
+  int ret = 0;
+  try {
+    ret = std::stoi(str);
+  } catch (const std::invalid_argument&) {
+    TORCH_CHECK(false, "String cannot be converted to an int:" + str);
+  } catch (const std::out_of_range&) {
+    TORCH_CHECK(false, "String would become integer out of range:" + str);
+  }
+  TORCH_CHECK(ret >= 0, "String must be a non-negative integer:" + str);
+  return ret;
+}
+
 // Resize transform specs take the form:
 //
 //   "resize, <height>, <width>"
@@ -270,8 +283,8 @@ Transform* makeCropTransform(
       "cropTransformSpec must have 5 elements including its name");
   int height = checkedToPositiveInt(cropTransformSpec[1]);
   int width = checkedToPositiveInt(cropTransformSpec[2]);
-  int x = checkedToPositiveInt(cropTransformSpec[3]);
-  int y = checkedToPositiveInt(cropTransformSpec[4]);
+  int x = checkedToNonNegativeInt(cropTransformSpec[3]);
+  int y = checkedToNonNegativeInt(cropTransformSpec[4]);
   return new CropTransform(FrameDims(height, width), x, y);
 }
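From makeCropTransform above, a crop spec is five comma-separated elements (name, height, width, x, y); the new helper relaxes x and y from positive to non-negative, so crops anchored at the origin now parse. An illustrative Python analogue (the exact name string of the spec is not shown in this hunk, so "crop" below is a placeholder):

def parse_crop_spec(spec: str):
    # Illustrative mirror of makeCropTransform's checks, not TorchCodec API.
    name, height, width, x, y = (s.strip() for s in spec.split(","))
    height, width, x, y = int(height), int(width), int(x), int(y)
    assert height > 0 and width > 0  # checkedToPositiveInt
    assert x >= 0 and y >= 0         # checkedToNonNegativeInt (new)
    return name, height, width, x, y

parse_crop_spec("crop, 100, 200, 0, 0")  # origin-anchored crops now parse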

src/torchcodec/decoders/_video_decoder.py

Lines changed: 68 additions & 41 deletions
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import List, Literal, Optional, Sequence, Tuple, Union
+from typing import Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import device as torch_device, nn, Tensor
@@ -19,7 +19,7 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import DecoderTransform, Resize
+from torchcodec.transforms import DecoderTransform, RandomCrop, Resize
 
 
 class VideoDecoder:
@@ -167,7 +167,10 @@ def __init__(
         device = str(device)
 
         device_variant = _get_cuda_backend()
-        transform_specs = _make_transform_specs(transforms)
+        transform_specs = _make_transform_specs(
+            transforms,
+            input_dims=(self.metadata.height, self.metadata.width),
+        )
 
         core.add_video_stream(
             self._decoder,
@@ -448,76 +451,100 @@ def _get_and_validate_stream_metadata(
         )
 
 
-def _convert_to_decoder_transforms(
-    transforms: Sequence[Union[DecoderTransform, nn.Module]],
-) -> List[DecoderTransform]:
-    """Convert a sequence of transforms that may contain TorchVision transform
-    objects into a list of only TorchCodec transform objects.
+def _make_transform_specs(
+    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
+    input_dims: Tuple[Optional[int], Optional[int]],
+) -> str:
+    """Given a sequence of transforms, turn those into the specification string
+    the core API expects.
 
     Args:
-        transforms: Squence of transform objects. The objects can be one of two
-            types:
+        transforms: Optional sequence of transform objects. The objects can be
+            one of two types:
             1. torchcodec.transforms.DecoderTransform
             2. torchvision.transforms.v2.Transform, but our type annotation
                only mentions its base, nn.Module. We don't want to take a
                hard dependency on TorchVision.
+        input_dims: Optional (height, width) pair. Note that only some
+            transforms need to know the dimensions. If the user provides
+            transforms that don't need to know the dimensions, and that metadata
+            is missing, everything should still work. That means we assert their
+            existence as late as possible.
 
     Returns:
-        List of DecoderTransform objects.
+        String of transforms in the format the core API expects: transform
+        specifications separate by semicolons.
     """
+    if transforms is None:
+        return ""
+
     try:
         from torchvision.transforms import v2
 
         tv_available = True
     except ImportError:
         tv_available = False
 
-    converted_transforms: list[DecoderTransform] = []
+    # The following loop accomplishes two tasks:
+    #
+    #   1. Converts the transform to a DecoderTransform, if necessary. We
+    #      accept TorchVision transform objects and they must be converted
+    #      to their matching DecoderTransform.
+    #   2. Calculates what the input dimensions are to each transform.
+    #
+    # The order in our transforms list is semantically meaningful, as we
+    # actually have a pipeline where the output of one transform is the input to
+    # the next. For example, if we have the transforms list [A, B, C, D], then
+    # we should understand that as:
+    #
+    #   A -> B -> C -> D
+    #
+    # Where the frame produced by A is the input to B, the frame produced by B
+    # is the input to C, etc. This particularly matters for frame dimensions.
+    # Transforms can both:
+    #
+    #   1. Produce frames with arbitrary dimensions.
+    #   2. Rely on their input frame's dimensions to calculate ahead-of-time
+    #      what their runtime behavior will be.
+    #
+    # The consequence of the above facts is that we need to statically track
+    # frame dimensions in the pipeline while we pre-process it. The input
+    # frame's dimensions to A, our first transform, is always what we know from
+    # our metadata. For each transform, we always calculate its output
+    # dimensions from its input dimensions. We store these with the converted
+    # transform, to be all used together when we generate the specs.
+    converted_transforms: list[
+        Tuple[
+            DecoderTransform,
+            # A (height, width) pair where the values may be missing.
+            Tuple[Optional[int], Optional[int]],
+        ]
+    ] = []
+    curr_input_dims = input_dims
     for transform in transforms:
         if not isinstance(transform, DecoderTransform):
             if not tv_available:
                 raise ValueError(
                     f"The supplied transform, {transform}, is not a TorchCodec "
-                    " DecoderTransform. TorchCodec also accept TorchVision "
+                    " DecoderTransform. TorchCodec also accepts TorchVision "
                     "v2 transforms, but TorchVision is not installed."
                 )
             elif isinstance(transform, v2.Resize):
-                converted_transforms.append(Resize._from_torchvision(transform))
+                transform = Resize._from_torchvision(transform)
+            elif isinstance(transform, v2.RandomCrop):
+                transform = RandomCrop._from_torchvision(transform)
             else:
                 raise ValueError(
                     f"Unsupported transform: {transform}. Transforms must be "
                     "either a TorchCodec DecoderTransform or a TorchVision "
                     "v2 transform."
                 )
-        else:
-            converted_transforms.append(transform)
-
-    return converted_transforms
 
+        converted_transforms.append((transform, curr_input_dims))
+        output_dims = transform._get_output_dims()
+        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
 
-def _make_transform_specs(
-    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-) -> str:
-    """Given a sequence of transforms, turn those into the specification string
-    the core API expects.
-
-    Args:
-        transforms: Optional sequence of transform objects. The objects can be
-            one of two types:
-            1. torchcodec.transforms.DecoderTransform
-            2. torchvision.transforms.v2.Transform, but our type annotation
-               only mentions its base, nn.Module. We don't want to take a
-               hard dependency on TorchVision.
-
-    Returns:
-        String of transforms in the format the core API expects: transform
-        specifications separate by semicolons.
-    """
-    if transforms is None:
-        return ""
-
-    transforms = _convert_to_decoder_transforms(transforms)
-    return ";".join([t._make_transform_spec() for t in transforms])
+    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
 
 
 def _read_custom_frame_mappings(
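The pairing loop above is easiest to see with concrete numbers. A self-contained sketch (illustrative sizes and names, not the TorchCodec API):

from typing import List, Optional, Tuple

Dims = Tuple[Optional[int], Optional[int]]  # (height, width), possibly unknown

def pair_with_input_dims(
    output_dims_per_transform: List[Optional[Dims]], metadata_dims: Dims
) -> List[Dims]:
    # Each transform is paired with the dims of the frame it will receive;
    # transforms with no fixed output (None) pass their input dims through.
    paired, curr = [], metadata_dims
    for out in output_dims_per_transform:
        paired.append(curr)
        curr = out if out is not None else curr
    return paired

# A resize to (270, 480) followed by a crop to (100, 200) on a 1080x1920
# stream: the crop is validated against the resized dims, not the original.
print(pair_with_input_dims([(270, 480), (100, 200)], (1080, 1920)))
# [(1080, 1920), (270, 480)]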

src/torchcodec/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from ._decoder_transforms import DecoderTransform, Resize  # noqa
+from ._decoder_transforms import DecoderTransform, RandomCrop, Resize  # noqa
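With RandomCrop now exported, both transforms can be passed to the decoder. A hedged usage sketch ("video.mp4" and the sizes are placeholders, and we assume the TorchCodec transforms accept (height, width) sizes like their TorchVision v2 counterparts):

from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import RandomCrop, Resize

decoder = VideoDecoder(
    "video.mp4",
    transforms=[Resize(size=(270, 480)), RandomCrop(size=(100, 200))],
)
frame = decoder[0]  # frames come back already resized and randomly cropped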
