Skip to content

Commit f353758

Browse files
authored
Merge branch 'meta-pytorch:main' into cpu-fallback
2 parents e97490e + 4e412b7 commit f353758

File tree

11 files changed

+651
-118
lines changed

11 files changed

+651
-118
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
# TorchCodec
44

55
TorchCodec is a Python library for decoding video and audio data into PyTorch
6-
tensors, on CPU and CUDA GPU. It also supports audio encoding, and video
7-
encoding will come soon! It aims to be fast, easy to use, and well integrated
6+
tensors, on CPU and CUDA GPU. It also supports video and audio encoding on CPU!
7+
It aims to be fast, easy to use, and well integrated
88
into the PyTorch ecosystem. If you want to use PyTorch to train ML models on
99
videos and audio, TorchCodec is how you turn these into data.
1010

@@ -130,7 +130,8 @@ The following table indicates the compatibility between versions of
130130

131131
| `torchcodec` | `torch` | Python |
132132
| ------------------ | ------------------ | ------------------- |
133-
| `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.13` |
133+
| `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.14` |
134+
| `0.9` | `2.9` | `>=3.10`, `<=3.14` |
134135
| `0.8` | `2.9` | `>=3.10`, `<=3.13` |
135136
| `0.7` | `2.8` | `>=3.9`, `<=3.13` |
136137
| `0.6` | `2.8` | `>=3.9`, `<=3.13` |

docs/source/api_ref_transforms.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ For a tutorial, see: TODO_DECODER_TRANSFORMS_TUTORIAL.
1414
:template: dataclass.rst
1515

1616
DecoderTransform
17+
CenterCrop
18+
RandomCrop
1719
Resize

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream(
545545

546546
metadataDims_ =
547547
FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
548+
FrameDims currInputDims = metadataDims_;
548549
for (auto& transform : transforms) {
549550
TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
550551
if (transform->getOutputFrameDims().has_value()) {
551552
resizedOutputDims_ = transform->getOutputFrameDims().value();
552553
}
553-
transform->validate(streamMetadata);
554+
transform->validate(currInputDims);
555+
currInputDims = resizedOutputDims_.value_or(metadataDims_);
554556

555557
// Note that we are claiming ownership of the transform objects passed in to
556558
// us.

src/torchcodec/_core/Transform.cpp

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,31 +37,72 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
3737
return outputDims_;
3838
}
3939

40+
CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {}
41+
4042
CropTransform::CropTransform(const FrameDims& dims, int x, int y)
4143
: outputDims_(dims), x_(x), y_(y) {
4244
TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
4345
TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_);
4446
}
4547

4648
std::string CropTransform::getFilterGraphCpu() const {
49+
// For the FFmpeg filter crop, if the x and y coordinates are left
50+
// unspecified, it defaults to a center crop.
51+
std::string coordinates = x_.has_value()
52+
? (":" + std::to_string(x_.value()) + ":" + std::to_string(y_.value()))
53+
: "";
4754
return "crop=" + std::to_string(outputDims_.width) + ":" +
48-
std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" +
49-
std::to_string(y_) + ":exact=1";
55+
std::to_string(outputDims_.height) + coordinates + ":exact=1";
5056
}
5157

5258
std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
5359
return outputDims_;
5460
}
5561

56-
void CropTransform::validate(const StreamMetadata& streamMetadata) const {
57-
TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
62+
void CropTransform::validate(const FrameDims& inputDims) const {
63+
TORCH_CHECK(
64+
outputDims_.height <= inputDims.height,
65+
"Crop output height (",
66+
outputDims_.height,
67+
") is greater than input height (",
68+
inputDims.height,
69+
")");
5870
TORCH_CHECK(
59-
x_ + outputDims_.width <= streamMetadata.width,
60-
"Crop x position out of bounds")
61-
TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
71+
outputDims_.width <= inputDims.width,
72+
"Crop output width (",
73+
outputDims_.width,
74+
") is greater than input width (",
75+
inputDims.width,
76+
")");
6277
TORCH_CHECK(
63-
y_ + outputDims_.height <= streamMetadata.height,
64-
"Crop y position out of bounds");
78+
x_.has_value() == y_.has_value(),
79+
"Crop x and y values must be both set or both unset");
80+
if (x_.has_value()) {
81+
TORCH_CHECK(
82+
x_.value() <= inputDims.width,
83+
"Crop x start position, ",
84+
x_.value(),
85+
", out of bounds of input width, ",
86+
inputDims.width);
87+
TORCH_CHECK(
88+
x_.value() + outputDims_.width <= inputDims.width,
89+
"Crop x end position, ",
90+
x_.value() + outputDims_.width,
91+
", out of bounds of input width ",
92+
inputDims.width);
93+
TORCH_CHECK(
94+
y_.value() <= inputDims.height,
95+
"Crop y start position, ",
96+
y_.value(),
97+
", out of bounds of input height, ",
98+
inputDims.height);
99+
TORCH_CHECK(
100+
y_.value() + outputDims_.height <= inputDims.height,
101+
"Crop y end position, ",
102+
y_.value() + outputDims_.height,
103+
", out of bounds of input height ",
104+
inputDims.height);
105+
}
65106
}
66107

67108
} // namespace facebook::torchcodec

src/torchcodec/_core/Transform.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ class Transform {
3636
//
3737
// Note that the validation function does not return anything. We expect
3838
// invalid configurations to throw an exception.
39-
virtual void validate(
40-
[[maybe_unused]] const StreamMetadata& streamMetadata) const {}
39+
virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {}
4140
};
4241

4342
class ResizeTransform : public Transform {
@@ -62,14 +61,17 @@ class CropTransform : public Transform {
6261
public:
6362
CropTransform(const FrameDims& dims, int x, int y);
6463

64+
// Becomes a center crop if x and y are not specified.
65+
CropTransform(const FrameDims& dims);
66+
6567
std::string getFilterGraphCpu() const override;
6668
std::optional<FrameDims> getOutputFrameDims() const override;
67-
void validate(const StreamMetadata& streamMetadata) const override;
69+
void validate(const FrameDims& inputDims) const override;
6870

6971
private:
7072
FrameDims outputDims_;
71-
int x_;
72-
int y_;
73+
std::optional<int> x_;
74+
std::optional<int> y_;
7375
};
7476

7577
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,19 @@ int checkedToPositiveInt(const std::string& str) {
239239
return ret;
240240
}
241241

242+
int checkedToNonNegativeInt(const std::string& str) {
243+
int ret = 0;
244+
try {
245+
ret = std::stoi(str);
246+
} catch (const std::invalid_argument&) {
247+
TORCH_CHECK(false, "String cannot be converted to an int:" + str);
248+
} catch (const std::out_of_range&) {
249+
TORCH_CHECK(false, "String would become integer out of range:" + str);
250+
}
251+
TORCH_CHECK(ret >= 0, "String must be a non-negative integer:" + str);
252+
return ret;
253+
}
254+
242255
// Resize transform specs take the form:
243256
//
244257
// "resize, <height>, <width>"
@@ -270,11 +283,28 @@ Transform* makeCropTransform(
270283
"cropTransformSpec must have 5 elements including its name");
271284
int height = checkedToPositiveInt(cropTransformSpec[1]);
272285
int width = checkedToPositiveInt(cropTransformSpec[2]);
273-
int x = checkedToPositiveInt(cropTransformSpec[3]);
274-
int y = checkedToPositiveInt(cropTransformSpec[4]);
286+
int x = checkedToNonNegativeInt(cropTransformSpec[3]);
287+
int y = checkedToNonNegativeInt(cropTransformSpec[4]);
275288
return new CropTransform(FrameDims(height, width), x, y);
276289
}
277290

291+
// CenterCrop transform specs take the form:
292+
//
293+
// "center_crop, <height>, <width>"
294+
//
295+
// Where "center_crop" is the string literal and <height>, <width> are
296+
// positive integers. Note that we follow the PyTorch convention of (height,
297+
// width) for specifying image dimensions; FFmpeg uses (width, height).
298+
Transform* makeCenterCropTransform(
299+
const std::vector<std::string>& cropTransformSpec) {
300+
TORCH_CHECK(
301+
cropTransformSpec.size() == 3,
302+
"cropTransformSpec must have 3 elements including its name");
303+
int height = checkedToPositiveInt(cropTransformSpec[1]);
304+
int width = checkedToPositiveInt(cropTransformSpec[2]);
305+
return new CropTransform(FrameDims(height, width));
306+
}
307+
278308
std::vector<std::string> split(const std::string& str, char delimiter) {
279309
std::vector<std::string> tokens;
280310
std::string token;
@@ -304,6 +334,8 @@ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
304334
transforms.push_back(makeResizeTransform(transformSpec));
305335
} else if (name == "crop") {
306336
transforms.push_back(makeCropTransform(transformSpec));
337+
} else if (name == "center_crop") {
338+
transforms.push_back(makeCenterCropTransform(transformSpec));
307339
} else {
308340
TORCH_CHECK(false, "Invalid transform name: " + name);
309341
}

src/torchcodec/decoders/_video_decoder.py

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numbers
1010
from dataclasses import dataclass
1111
from pathlib import Path
12-
from typing import List, Literal, Optional, Sequence, Tuple, Union
12+
from typing import Literal, Optional, Sequence, Tuple, Union
1313

1414
import torch
1515
from torch import device as torch_device, nn, Tensor
@@ -20,7 +20,8 @@
2020
create_decoder,
2121
ERROR_REPORTING_INSTRUCTIONS,
2222
)
23-
from torchcodec.transforms import DecoderTransform, Resize
23+
from torchcodec.transforms import DecoderTransform
24+
from torchcodec.transforms._decoder_transforms import _make_transform_specs
2425

2526

2627
@dataclass
@@ -217,7 +218,10 @@ def __init__(
217218
device = str(device)
218219

219220
device_variant = _get_cuda_backend()
220-
transform_specs = _make_transform_specs(transforms)
221+
transform_specs = _make_transform_specs(
222+
transforms,
223+
input_dims=(self.metadata.height, self.metadata.width),
224+
)
221225

222226
core.add_video_stream(
223227
self._decoder,
@@ -523,78 +527,6 @@ def _get_and_validate_stream_metadata(
523527
)
524528

525529

526-
def _convert_to_decoder_transforms(
527-
transforms: Sequence[Union[DecoderTransform, nn.Module]],
528-
) -> List[DecoderTransform]:
529-
"""Convert a sequence of transforms that may contain TorchVision transform
530-
objects into a list of only TorchCodec transform objects.
531-
532-
Args:
533-
transforms: Sequence of transform objects. The objects can be one of two
534-
types:
535-
1. torchcodec.transforms.DecoderTransform
536-
2. torchvision.transforms.v2.Transform, but our type annotation
537-
only mentions its base, nn.Module. We don't want to take a
538-
hard dependency on TorchVision.
539-
540-
Returns:
541-
List of DecoderTransform objects.
542-
"""
543-
try:
544-
from torchvision.transforms import v2
545-
546-
tv_available = True
547-
except ImportError:
548-
tv_available = False
549-
550-
converted_transforms: list[DecoderTransform] = []
551-
for transform in transforms:
552-
if not isinstance(transform, DecoderTransform):
553-
if not tv_available:
554-
raise ValueError(
555-
f"The supplied transform, {transform}, is not a TorchCodec "
556-
" DecoderTransform. TorchCodec also accept TorchVision "
557-
"v2 transforms, but TorchVision is not installed."
558-
)
559-
elif isinstance(transform, v2.Resize):
560-
converted_transforms.append(Resize._from_torchvision(transform))
561-
else:
562-
raise ValueError(
563-
f"Unsupported transform: {transform}. Transforms must be "
564-
"either a TorchCodec DecoderTransform or a TorchVision "
565-
"v2 transform."
566-
)
567-
else:
568-
converted_transforms.append(transform)
569-
570-
return converted_transforms
571-
572-
573-
def _make_transform_specs(
574-
transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
575-
) -> str:
576-
"""Given a sequence of transforms, turn those into the specification string
577-
the core API expects.
578-
579-
Args:
580-
transforms: Optional sequence of transform objects. The objects can be
581-
one of two types:
582-
1. torchcodec.transforms.DecoderTransform
583-
2. torchvision.transforms.v2.Transform, but our type annotation
584-
only mentions its base, nn.Module. We don't want to take a
585-
hard dependency on TorchVision.
586-
587-
Returns:
588-
String of transforms in the format the core API expects: transform
589-
specifications separate by semicolons.
590-
"""
591-
if transforms is None:
592-
return ""
593-
594-
transforms = _convert_to_decoder_transforms(transforms)
595-
return ";".join([t._make_transform_spec() for t in transforms])
596-
597-
598530
def _read_custom_frame_mappings(
599531
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
600532
) -> tuple[Tensor, Tensor, Tensor]:

src/torchcodec/transforms/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from ._decoder_transforms import DecoderTransform, Resize # noqa
7+
from ._decoder_transforms import ( # noqa
8+
CenterCrop,
9+
DecoderTransform,
10+
RandomCrop,
11+
Resize,
12+
)

0 commit comments

Comments
 (0)