
Commit c59de36
committed

Initial C++ implementation of transforms

1 parent ad1631d commit c59de36

17 files changed: +413 additions, −298 deletions

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ function(make_torchcodec_libraries
       SingleStreamDecoder.cpp
       Encoder.cpp
       ValidationUtils.cpp
+      Transform.cpp
   )

   if(ENABLE_CUDA)

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 93 additions & 57 deletions
@@ -46,6 +46,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }

+void CpuDeviceInterface::initialize(
+    [[maybe_unused]] AVCodecContext* codecContext,
+    const VideoStreamOptions& videoStreamOptions,
+    const std::vector<std::unique_ptr<Transform>>& transforms,
+    const AVRational& timeBase,
+    const FrameDims& outputDims) {
+  videoStreamOptions_ = videoStreamOptions;
+  timeBase_ = timeBase;
+  outputDims_ = outputDims;
+
+  // TODO: rationalize comment below with new stuff.
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested in the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose the color conversion library publicly; we only use this
+  // ability internally.
+
+  // If any transform is not swscale-compatible, then we can't use swscale.
+  bool areTransformsSwScaleCompatible = true;
+  for (const auto& transform : transforms) {
+    areTransformsSwScaleCompatible =
+        areTransformsSwScaleCompatible && transform->isSwScaleCompatible();
+  }
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0;
+
+  bool userRequestedSwScale =
+      videoStreamOptions_.colorConversionLibrary.has_value() &&
+      videoStreamOptions_.colorConversionLibrary.value() ==
+          ColorConversionLibrary::SWSCALE;
+
+  // Note that we treat the transform limitation differently from the width
+  // limitation. That is, we consider the transforms being compatible with
+  // sws_scale a hard requirement. If the transforms are not compatible,
+  // then we would end up not applying the transforms, and that is wrong.
+  //
+  // The width requirement, however, is a soft requirement. Even if we don't
+  // meet it, we let the user override it. We have tests that depend on this
+  // behavior. Since we don't expose the ability to choose swscale or
+  // filtergraph in our public API, this is probably okay. It's also the only
+  // way that we can be certain we are testing one versus the other.
+  if (areTransformsSwScaleCompatible &&
+      (userRequestedSwScale || isWidthSwScaleCompatible)) {
+    colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
+  } else {
+    colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH;
+
+    // If we have any transforms, replace filters_ with the filter strings
+    // from the transforms.
+    std::stringstream filters;
+    bool first = true;
+    for (const auto& transform : transforms) {
+      if (!first) {
+        filters << ",";
+      }
+      filters << transform->getFilterGraphCpu();
+      first = false;
+    }
+    if (!transforms.empty()) {
+      filters_ = filters.str();
+    }
+  }
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
@@ -56,25 +124,18 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
-  int expectedOutputHeight = frameDims.height;
-  int expectedOutputWidth = frameDims.width;
-
   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
+            (shape[1] == outputDims_.width) && (shape[2] == 3),
         "Expected pre-allocated tensor of shape ",
-        expectedOutputHeight,
+        outputDims_.height,
         "x",
-        expectedOutputWidth,
+        outputDims_.width,
         "x3, got ",
         shape);
   }
@@ -83,25 +144,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);

-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
-  ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
-
-  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+  if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
     // context. If they are different, then we need to re-create our colorspace
     // conversion objects. We create our colorspace conversion objects late so
@@ -113,11 +156,11 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
         avFrame->width,
         avFrame->height,
         frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight);
+        outputDims_.width,
+        outputDims_.height);

-    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+    outputTensor = preAllocatedOutputTensor.value_or(
+        allocateEmptyHWCTensor(outputDims_, torch::kCPU));

     if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
       createSwsContext(swsFrameContext, avFrame->colorspace);
@@ -129,34 +172,28 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     // the expected height.
     // TODO: Can we do the same check for width?
     TORCH_CHECK(
-        resultHeight == expectedOutputHeight,
-        "resultHeight != expectedOutputHeight: ",
+        resultHeight == outputDims_.height,
+        "resultHeight != outputDims_.height: ",
         resultHeight,
         " != ",
-        expectedOutputHeight);
+        outputDims_.height);

     frameOutput.data = outputTensor;
-  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation. creation
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
+  } else if (colorConversionLibrary_ == ColorConversionLibrary::FILTERGRAPH) {
     FiltersContext filtersContext(
         avFrame->width,
         avFrame->height,
         frameFormat,
         avFrame->sample_aspect_ratio,
-        expectedOutputWidth,
-        expectedOutputHeight,
+        outputDims_.width,
+        outputDims_.height,
         AV_PIX_FMT_RGB24,
-        filters.str(),
-        timeBase);
+        filters_,
+        timeBase_);

     if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
       filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
       prevFiltersContext_ = std::move(filtersContext);
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -165,12 +202,12 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     // reshaped to its expected dimensions by filtergraph.
     auto shape = outputTensor.sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
+            (shape[1] == outputDims_.width) && (shape[2] == 3),
         "Expected output tensor of shape ",
-        expectedOutputHeight,
+        outputDims_.height,
         "x",
-        expectedOutputWidth,
+        outputDims_.width,
         "x3, got ",
         shape);
@@ -186,7 +223,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     TORCH_CHECK(
         false,
         "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
+        static_cast<int>(colorConversionLibrary_));
   }
 }

@@ -214,9 +251,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(

   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
+  int height = filteredAVFrame->height;
+  int width = filteredAVFrame->width;
   std::vector<int64_t> shape = {height, width, 3};
   std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
   AVFrame* filteredAVFramePtr = filteredAVFrame.release();
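
Note: the initialize logic above depends on each Transform exposing isSwScaleCompatible() and getFilterGraphCpu(). Transform.h/Transform.cpp are among the 17 changed files but are not shown in this excerpt, so the following is only a minimal sketch of what a conforming transform could look like; the interface shape and the ResizeTransform class are assumptions for illustration, not the commit's actual declarations.

// Sketch only: the real interface lives in Transform.h, which this
// excerpt does not show.
#include <sstream>
#include <string>

class Transform {
 public:
  virtual ~Transform() = default;
  // True if swscale alone can express this transform; anything beyond
  // scaling/color conversion must return false to force the filtergraph path.
  virtual bool isSwScaleCompatible() const = 0;
  // FFmpeg filtergraph fragment implementing this transform on CPU.
  // CpuDeviceInterface::initialize joins these fragments with commas.
  virtual std::string getFilterGraphCpu() const = 0;
};

// Hypothetical example transform, assumed for illustration.
class ResizeTransform : public Transform {
 public:
  ResizeTransform(int outputWidth, int outputHeight)
      : outputWidth_(outputWidth), outputHeight_(outputHeight) {}

  bool isSwScaleCompatible() const override {
    return true;  // plain resizing is exactly what swscale can do
  }

  std::string getFilterGraphCpu() const override {
    std::ostringstream filter;
    filter << "scale=" << outputWidth_ << ":" << outputHeight_
           << ":sws_flags=bilinear";
    return filter.str();
  }

 private:
  int outputWidth_;
  int outputHeight_;
};

Two such transforms would yield a combined filter string like "scale=640:480:sws_flags=bilinear,<next filter>" once initialize joins the fragments with commas.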

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 16 additions & 4 deletions
@@ -23,12 +23,14 @@ class CpuDeviceInterface : public DeviceInterface {
     return std::nullopt;
   }

-  void initializeContext(
-      [[maybe_unused]] AVCodecContext* codecContext) override {}
-
-  void convertAVFrameToFrameOutput(
+  virtual void initialize(
+      [[maybe_unused]] AVCodecContext* codecContext,
       const VideoStreamOptions& videoStreamOptions,
+      const std::vector<std::unique_ptr<Transform>>& transforms,
       const AVRational& timeBase,
+      const FrameDims& outputDims) override;
+
+  void convertAVFrameToFrameOutput(
       UniqueAVFrame& avFrame,
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor =
@@ -64,6 +66,16 @@ class CpuDeviceInterface : public DeviceInterface {
       const SwsFrameContext& swsFrameContext,
       const enum AVColorSpace colorspace);

+  VideoStreamOptions videoStreamOptions_;
+  ColorConversionLibrary colorConversionLibrary_;
+  AVRational timeBase_;
+  FrameDims outputDims_;
+
+  // The copy filter just copies the input to the output. Computationally, it
+  // should be a no-op. If we get no user-provided transforms, we will use the
+  // copy filter.
+  std::string filters_ = "copy";
+
   // color-conversion fields. Only one of FilterGraphContext and
   // UniqueSwsContext should be non-null.
   std::unique_ptr<FilterGraph> filterGraphContext_;
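
The fields above are set once by initialize. The choice between the two libraries reduces to: transform compatibility is a hard gate, while the width rule is a soft default that an explicit user request may override. Distilled into a standalone predicate (a hypothetical helper for illustration, not code from this commit):

// Mirrors the branch in CpuDeviceInterface::initialize; uses the codebase's
// existing ColorConversionLibrary enum.
ColorConversionLibrary chooseColorConversionLibrary(
    bool areTransformsSwScaleCompatible,  // hard requirement
    bool isWidthSwScaleCompatible,        // soft: width % 32 == 0
    bool userRequestedSwScale) {          // may override the width rule
  if (areTransformsSwScaleCompatible &&
      (userRequestedSwScale || isWidthSwScaleCompatible)) {
    return ColorConversionLibrary::SWSCALE;
  }
  return ColorConversionLibrary::FILTERGRAPH;
}

The filters_ default of "copy" means that when filtergraph is chosen with no user-provided transforms, the graph applies a computational no-op and the only real work is the conversion to RGB24.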

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 15 additions & 19 deletions
@@ -185,9 +185,16 @@ CudaDeviceInterface::~CudaDeviceInterface() {
   }
 }

-void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
+void CudaDeviceInterface::initialize(
+    AVCodecContext* codecContext,
+    [[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
+    [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
+    [[maybe_unused]] const AVRational& timeBase,
+    const FrameDims& outputDims) {
   TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");

+  outputDims_ = outputDims;
+
   // It is important for pytorch itself to create the cuda context. If ffmpeg
   // creates the context it may not be compatible with pytorch.
@@ -196,12 +203,9 @@ void CudaDeviceInterface::initialize(
   ctx_ = getCudaContext(device_);
   nppCtx_ = getNppStreamContext(device_);
   codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
-  return;
 }

 void CudaDeviceInterface::convertAVFrameToFrameOutput(
-    const VideoStreamOptions& videoStreamOptions,
-    [[maybe_unused]] const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -219,11 +223,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(

   FrameOutput cpuFrameOutput;
   cpuInterface->convertAVFrameToFrameOutput(
-      videoStreamOptions,
-      timeBase,
-      avFrame,
-      cpuFrameOutput,
-      preAllocatedOutputTensor);
+      avFrame, cpuFrameOutput, preAllocatedOutputTensor);

   frameOutput.data = cpuFrameOutput.data.to(device_);
   return;
@@ -253,25 +253,21 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
       "If the video is 10bit, we are tracking 10bit support in "
       "https://github.com/pytorch/torchcodec/issues/776");

-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
-  int height = frameDims.height;
-  int width = frameDims.width;
   torch::Tensor& dst = frameOutput.data;
   if (preAllocatedOutputTensor.has_value()) {
     dst = preAllocatedOutputTensor.value();
     auto shape = dst.sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
-            (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
+            (shape[1] == outputDims_.width) && (shape[2] == 3),
         "Expected tensor of shape ",
-        height,
+        outputDims_.height,
         "x",
-        width,
+        outputDims_.width,
         "x3, got ",
         shape);
   } else {
-    dst = allocateEmptyHWCTensor(height, width, device_);
+    dst = allocateEmptyHWCTensor(outputDims_, device_);
   }

   torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
@@ -308,7 +304,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
       "cudaStreamGetFlags failed: ",
       cudaGetErrorString(err));

-  NppiSize oSizeROI = {width, height};
+  NppiSize oSizeROI = {outputDims_.width, outputDims_.height};
   Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};

   NppStatus status;

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 9 additions & 1 deletion
@@ -19,17 +19,25 @@ class CudaDeviceInterface : public DeviceInterface {

   std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;

-  void initializeContext(AVCodecContext* codecContext) override;
+  void initialize(
+      AVCodecContext* codecContext,
+      [[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
+      [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>&
+          transforms,
+      [[maybe_unused]] const AVRational& timeBase,
+      const FrameDims& outputDims) override;

   void convertAVFrameToFrameOutput(
       const VideoStreamOptions& videoStreamOptions,
       const AVRational& timeBase,
       UniqueAVFrame& avFrame,
+      const FrameDims& outputDims,
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor =
           std::nullopt) override;

 private:
+  FrameDims outputDims_;
   UniqueAVBufferRef ctx_;
   std::unique_ptr<NppStreamContext> nppCtx_;
 };
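
Taken together, the interfaces now split one-time setup from per-frame work. A caller-side sketch of the new flow, assuming the decoder-side wiring (codecContext, videoStreamOptions, timeBase, and the FrameDims value live in files outside this excerpt) and reusing the hypothetical ResizeTransform from the earlier sketch:

// Hypothetical caller-side flow; the variable wiring is assumed, not shown
// in this commit excerpt.
std::vector<std::unique_ptr<Transform>> transforms;
transforms.push_back(std::make_unique<ResizeTransform>(640, 480));

// One-time setup: the device interface picks swscale vs filtergraph here
// and caches videoStreamOptions, timeBase, and outputDims for all frames.
deviceInterface->initialize(
    codecContext, videoStreamOptions, transforms, timeBase, outputDims);

// Per-frame conversion no longer takes the options or time base; on CPU it
// reads the cached colorConversionLibrary_, filters_, and outputDims_.
FrameOutput frameOutput;
deviceInterface->convertAVFrameToFrameOutput(avFrame, frameOutput);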
