Skip to content

Commit 3f63c85

Browse files
scottsNicolasHug
authored andcommitted
C++ implementation of crop transform (meta-pytorch#967)
1 parent 9c3c5d2 commit 3f63c85

16 files changed

+467
-198
lines changed

.github/workflows/reference_resources.yaml

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,51 @@ defaults:
1414
shell: bash -l -eo pipefail {0}
1515

1616
jobs:
17+
generate-matrix:
18+
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
19+
with:
20+
package-type: wheel
21+
os: linux
22+
test-infra-repository: pytorch/test-infra
23+
test-infra-ref: main
24+
with-xpu: disable
25+
with-rocm: disable
26+
with-cuda: disable
27+
build-python-only: "disable"
28+
29+
build:
30+
needs: generate-matrix
31+
strategy:
32+
fail-fast: false
33+
name: Build and Upload Linux wheel
34+
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
35+
with:
36+
repository: meta-pytorch/torchcodec
37+
ref: ""
38+
test-infra-repository: pytorch/test-infra
39+
test-infra-ref: main
40+
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
41+
pre-script: packaging/pre_build_script.sh
42+
post-script: packaging/post_build_script.sh
43+
smoke-test-script: packaging/fake_smoke_test.py
44+
package-name: torchcodec
45+
trigger-event: ${{ github.event_name }}
46+
build-platform: "python-build-package"
47+
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation"
48+
1749
test-reference-resource-generation:
50+
needs: build
1851
runs-on: ubuntu-latest
1952
strategy:
2053
fail-fast: false
2154
matrix:
2255
python-version: ['3.10']
2356
ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1']
2457
steps:
58+
- uses: actions/download-artifact@v4
59+
with:
60+
name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64
61+
path: pytorch/torchcodec/dist/
2562
- name: Setup conda env
2663
uses: conda-incubator/setup-miniconda@v2
2764
with:
@@ -43,11 +80,16 @@ jobs:
4380
# Note that we're installing stable - this is for running a script where we're a normal PyTorch
4481
# user, not for building TorhCodec.
4582
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
46-
python -m pip install numpy pillow
83+
python -m pip install numpy pillow pytest
4784
85+
- name: Install torchcodec from the wheel
86+
run: |
87+
wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
88+
echo Installing $wheel_path
89+
python -m pip install $wheel_path -vvv
4890
- name: Check out repo
4991
uses: actions/checkout@v3
5092

5193
- name: Run generation reference resources
5294
run: |
53-
python test/generate_reference_resources.py
95+
python -m test.generate_reference_resources

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ FilterGraph::FilterGraph(
130130
TORCH_CHECK(
131131
status >= 0,
132132
"Failed to configure filter graph: ",
133-
getFFMPEGErrorStringFromErrorCode(status));
133+
getFFMPEGErrorStringFromErrorCode(status),
134+
", provided filters: " + filtersContext.filtergraphStr);
134135
}
135136

136137
UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {

src/torchcodec/_core/Frame.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
namespace facebook::torchcodec {
1010

11+
FrameDims::FrameDims(int height, int width) : height(height), width(width) {
12+
TORCH_CHECK(height > 0, "FrameDims.height must be > 0, got: ", height);
13+
TORCH_CHECK(width > 0, "FrameDims.width must be > 0, got: ", width);
14+
}
15+
1116
FrameBatchOutput::FrameBatchOutput(
1217
int64_t numFrames,
1318
const FrameDims& outputDims,

src/torchcodec/_core/Frame.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct FrameDims {
1919

2020
FrameDims() = default;
2121

22-
FrameDims(int h, int w) : height(h), width(w) {}
22+
FrameDims(int h, int w);
2323
};
2424

2525
// All public video decoding entry points return either a FrameOutput or a

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <sstream>
1313
#include <stdexcept>
1414
#include <string_view>
15+
#include "Metadata.h"
1516
#include "torch/types.h"
1617

1718
namespace facebook::torchcodec {
@@ -527,6 +528,7 @@ void SingleStreamDecoder::addVideoStream(
527528
if (transform->getOutputFrameDims().has_value()) {
528529
resizedOutputDims_ = transform->getOutputFrameDims().value();
529530
}
531+
transform->validate(streamMetadata);
530532

531533
// Note that we are claiming ownership of the transform objects passed in to
532534
// us.

src/torchcodec/_core/Transform.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,31 @@ int ResizeTransform::getSwsFlags() const {
5757
return toSwsInterpolation(interpolationMode_);
5858
}
5959

60+
CropTransform::CropTransform(const FrameDims& dims, int x, int y)
61+
: outputDims_(dims), x_(x), y_(y) {
62+
TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
63+
TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_);
64+
}
65+
66+
std::string CropTransform::getFilterGraphCpu() const {
67+
return "crop=" + std::to_string(outputDims_.width) + ":" +
68+
std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" +
69+
std::to_string(y_) + ":exact=1";
70+
}
71+
72+
std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
73+
return outputDims_;
74+
}
75+
76+
void CropTransform::validate(const StreamMetadata& streamMetadata) const {
77+
TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
78+
TORCH_CHECK(
79+
x_ + outputDims_.width <= streamMetadata.width,
80+
"Crop x position out of bounds")
81+
TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
82+
TORCH_CHECK(
83+
y_ + outputDims_.height <= streamMetadata.height,
84+
"Crop y position out of bounds");
85+
}
86+
6087
} // namespace facebook::torchcodec

src/torchcodec/_core/Transform.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <optional>
1010
#include <string>
1111
#include "src/torchcodec/_core/Frame.h"
12+
#include "src/torchcodec/_core/Metadata.h"
1213

1314
namespace facebook::torchcodec {
1415

@@ -33,6 +34,16 @@ class Transform {
3334
virtual bool isResize() const {
3435
return false;
3536
}
37+
38+
// The validity of some transforms depends on the characteristics of the
39+
// AVStream they're being applied to. For example, some transforms will
40+
// specify coordinates inside a frame, we need to validate that those are
41+
// within the frame's bounds.
42+
//
43+
// Note that the validation function does not return anything. We expect
44+
// invalid configurations to throw an exception.
45+
virtual void validate(
46+
[[maybe_unused]] const StreamMetadata& streamMetadata) const {}
3647
};
3748

3849
class ResizeTransform : public Transform {
@@ -56,4 +67,18 @@ class ResizeTransform : public Transform {
5667
InterpolationMode interpolationMode_;
5768
};
5869

70+
class CropTransform : public Transform {
71+
public:
72+
CropTransform(const FrameDims& dims, int x, int y);
73+
74+
std::string getFilterGraphCpu() const override;
75+
std::optional<FrameDims> getOutputFrameDims() const override;
76+
void validate(const StreamMetadata& streamMetadata) const override;
77+
78+
private:
79+
FrameDims outputDims_;
80+
int x_;
81+
int y_;
82+
};
83+
5984
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,26 @@ Transform* makeResizeTransform(
214214
return new ResizeTransform(FrameDims(height, width));
215215
}
216216

217+
// Crop transform specs take the form:
218+
//
219+
// "crop, <height>, <width>, <x>, <y>"
220+
//
221+
// Where "crop" is the string literal and <height>, <width>, <x> and <y> are
222+
// positive integers. <x> and <y> are the x and y coordinates of the top left
223+
// corner of the crop. Note that we follow the PyTorch convention of (height,
224+
// width) for specifying image dimensions; FFmpeg uses (width, height).
225+
Transform* makeCropTransform(
226+
const std::vector<std::string>& cropTransformSpec) {
227+
TORCH_CHECK(
228+
cropTransformSpec.size() == 5,
229+
"cropTransformSpec must have 5 elements including its name");
230+
int height = checkedToPositiveInt(cropTransformSpec[1]);
231+
int width = checkedToPositiveInt(cropTransformSpec[2]);
232+
int x = checkedToPositiveInt(cropTransformSpec[3]);
233+
int y = checkedToPositiveInt(cropTransformSpec[4]);
234+
return new CropTransform(FrameDims(height, width), x, y);
235+
}
236+
217237
std::vector<std::string> split(const std::string& str, char delimiter) {
218238
std::vector<std::string> tokens;
219239
std::string token;
@@ -241,6 +261,8 @@ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
241261
auto name = transformSpec[0];
242262
if (name == "resize") {
243263
transforms.push_back(makeResizeTransform(transformSpec));
264+
} else if (name == "crop") {
265+
transforms.push_back(makeCropTransform(transformSpec));
244266
} else {
245267
TORCH_CHECK(false, "Invalid transform name: " + name);
246268
}

test/generate_reference_resources.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,15 @@
1212
import torch
1313
from PIL import Image
1414

15+
from .utils import sanitize_filtergraph_expression
16+
1517
# Run this script to update the resources used in unit tests. The resources are all derived
1618
# from source media already checked into the repo.
1719

20+
SCRIPT_DIR = Path(__file__).resolve().parent
21+
TORCHCODEC_PATH = SCRIPT_DIR.parent
22+
RESOURCES_DIR = TORCHCODEC_PATH / "test" / "resources"
23+
1824

1925
def convert_image_to_tensor(image_path):
2026
image_path = Path(image_path)
@@ -31,7 +37,18 @@ def convert_image_to_tensor(image_path):
3137
image_path.unlink()
3238

3339

34-
def get_frame_by_index(video_path, frame, output_path, stream):
40+
def get_frame_by_index(video_path, frame, output_path, stream, filters=None):
41+
# Note that we have an exlicit format conversion to rgb24 in our filtergraph specification,
42+
# which always happens BEFORE any of the filters that we receive as input. We do this to
43+
# ensure that the color conversion happens BEFORE the filters, matching the behavior of the
44+
# torchcodec filtergraph implementation.
45+
#
46+
# Not doing this would result in the color conversion happening AFTER the filters, which
47+
# would result in different color values for the same frame.
48+
filtergraph = f"select='eq(n\\,{frame})',format=rgb24"
49+
if filters is not None:
50+
filtergraph = filtergraph + f",{filters}"
51+
3552
cmd = [
3653
"ffmpeg",
3754
"-y",
@@ -40,11 +57,11 @@ def get_frame_by_index(video_path, frame, output_path, stream):
4057
"-map",
4158
f"0:{stream}",
4259
"-vf",
43-
f"select=eq(n\\,{frame})",
44-
"-vsync",
45-
"vfr",
46-
"-q:v",
47-
"2",
60+
filtergraph,
61+
"-fps_mode",
62+
"passthrough",
63+
"-update",
64+
"1",
4865
output_path,
4966
]
5067
subprocess.run(cmd, check=True)
@@ -65,14 +82,9 @@ def get_frame_by_timestamp(video_path, timestamp, output_path):
6582
subprocess.run(cmd, check=True)
6683

6784

68-
def main():
69-
SCRIPT_DIR = Path(__file__).resolve().parent
70-
TORCHCODEC_PATH = SCRIPT_DIR.parent
71-
RESOURCES_DIR = TORCHCODEC_PATH / "test" / "resources"
85+
def generate_nasa_13013_references():
7286
VIDEO_PATH = RESOURCES_DIR / "nasa_13013.mp4"
7387

74-
# Last generated with ffmpeg version 4.3
75-
#
7688
# Note: The naming scheme used here must match the naming scheme used to load
7789
# tensors in ./utils.py.
7890
STREAMS = [0, 3]
@@ -95,6 +107,16 @@ def main():
95107
get_frame_by_timestamp(VIDEO_PATH, timestamp, output_bmp)
96108
convert_image_to_tensor(output_bmp)
97109

110+
# Extract frames with specific filters. We have tests that assume these exact filters.
111+
FRAMES = [0, 15, 200, 389]
112+
crop_filter = "crop=300:200:50:35:exact=1"
113+
for frame in FRAMES:
114+
output_bmp = f"{VIDEO_PATH}.{sanitize_filtergraph_expression(crop_filter)}.stream3.frame{frame:06d}.bmp"
115+
get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=3, filters=crop_filter)
116+
convert_image_to_tensor(output_bmp)
117+
118+
119+
def generate_h265_video_references():
98120
# This video was generated by running the following:
99121
# conda install -c conda-forge x265
100122
# ./configure --enable-nonfree --enable-gpl --prefix=$(readlink -f ../bin) --enable-libx265 --enable-rpath --extra-ldflags=-Wl,-rpath=$CONDA_PREFIX/lib --enable-filter=drawtext --enable-libfontconfig --enable-libfreetype --enable-libharfbuzz
@@ -107,6 +129,8 @@ def main():
107129
get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=0)
108130
convert_image_to_tensor(output_bmp)
109131

132+
133+
def generate_av1_video_references():
110134
# This video was generated by running the following:
111135
# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv
112136
# Note that this video only has 1 stream, at index 0.
@@ -119,5 +143,11 @@ def main():
119143
convert_image_to_tensor(output_bmp)
120144

121145

146+
def main():
147+
generate_nasa_13013_references()
148+
generate_h265_video_references()
149+
generate_av1_video_references()
150+
151+
122152
if __name__ == "__main__":
123153
main()
Binary file not shown.

0 commit comments

Comments
 (0)