Commit ee3b9b7
Apply reviewer suggestions
1 parent 1a07828

8 files changed: +39, -43 lines

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   // FrameBatchOutputs based on the the stream metadata. But single-frame APIs
   // can still work in such situations, so they should.
   auto outputDims =
-      resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height));
+      resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width));

   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 8 additions & 4 deletions

@@ -96,10 +96,14 @@ class CpuDeviceInterface : public DeviceInterface {
   UniqueSwsContext swsContext_;
   SwsFrameContext prevSwsFrameContext_;

-  // The filter we supply to filterGraph_, if it is used. The copy filter just
-  // copies the input to the output. Computationally, it should be a no-op. If
-  // we get no user-provided transforms, we will use the copy filter. Otherwise,
-  // we will construct the string from the transforms.
+  // The filter we supply to filterGraph_, if it is used. The default is the
+  // copy filter, which just copies the input to the output. Computationally, it
+  // should be a no-op. If we get no user-provided transforms, we will use the
+  // copy filter. Otherwise, we will construct the string from the transforms.
+  //
+  // Note that even if we only use the copy filter, we still get the desired
+  // colorspace conversion. We construct the filtergraph with its output sink
+  // set to RGB24.
   std::string filters_ = "copy";

   // The flags we supply to swsContext_, if it used. The flags control the
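
Aside (not part of the commit): a minimal sketch of how the filter string described in the comment above might be assembled, assuming the Transform interface declared in Transform.h below; the helper name buildFilterString is illustrative, not TorchCodec's API.

#include <memory>
#include <string>
#include <vector>

// Hypothetical helper, for illustration only: fall back to the no-op "copy"
// filter when there are no user transforms; otherwise chain each transform's
// filtergraph description with commas, FFmpeg's filter separator. Even the
// "copy" case still yields RGB24 output, because the filtergraph's sink is
// configured for RGB24.
std::string buildFilterString(
    const std::vector<std::unique_ptr<Transform>>& transforms) {
  if (transforms.empty()) {
    return "copy";
  }
  std::string filters;
  for (size_t i = 0; i < transforms.size(); ++i) {
    if (i > 0) {
      filters += ",";
    }
    filters += transforms[i]->getFilterGraphCpu();
  }
  return filters;
}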

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 16 additions & 15 deletions

@@ -174,6 +174,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
   TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
   TORCH_CHECK(
       device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
+
+  // It is important for pytorch itself to create the cuda context. If ffmpeg
+  // creates the context it may not be compatible with pytorch.
+  // This is a dummy tensor to initialize the cuda context.
+  torch::Tensor dummyTensorForCudaInitialization = torch::empty(
+      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
+  ctx_ = getCudaContext(device_);
+  nppCtx_ = getNppStreamContext(device_);
 }

 CudaDeviceInterface::~CudaDeviceInterface() {
@@ -191,20 +199,12 @@ void CudaDeviceInterface::initialize(
     [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
     const AVRational& timeBase,
     [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
-  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
+  TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
   TORCH_CHECK(codecContext != nullptr, "codecContext is null");

+  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
   videoStreamOptions_ = videoStreamOptions;
   timeBase_ = timeBase;
-
-  // It is important for pytorch itself to create the cuda context. If ffmpeg
-  // creates the context it may not be compatible with pytorch.
-  // This is a dummy tensor to initialize the cuda context.
-  torch::Tensor dummyTensorForCudaInitialization = torch::empty(
-      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
-  ctx_ = getCudaContext(device_);
-  nppCtx_ = getNppStreamContext(device_);
-  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
 }

 UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
@@ -304,7 +304,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Note that CUDA does not yet support transforms, so the only possible
   // frame dimensions are the raw decoded frame's dimensions.
-  auto frameDims = FrameDims(avFrame->width, avFrame->height);
+  auto frameDims = FrameDims(avFrame->height, avFrame->width);

   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();
@@ -379,14 +379,15 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(

   // Above we checked that the AVFrame was on GPU, but that's not enough, we
   // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
-  // because this is what the NPP color conversion routines expect.
+  // because this is what the NPP color conversion routines expect. This SHOULD
+  // be enforced by our call to maybeConvertAVFrameToNV12() above.
+  auto hwFramesCtx =
+      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
   TORCH_CHECK(
-      avFrame->hw_frames_ctx != nullptr,
+      hwFramesCtx != nullptr,
       "The AVFrame does not have a hw_frames_ctx. "
       "That's unexpected, please report this to the TorchCodec repo.");

-  auto hwFramesCtx =
-      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
   AVPixelFormat actualFormat = hwFramesCtx->sw_format;

   TORCH_CHECK(
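
Aside (not part of the commit): the constructor change above makes PyTorch, rather than FFmpeg, create the CUDA context before any FFmpeg hardware-device setup. Below is a minimal standalone sketch of that ordering, using FFmpeg's av_hwdevice_ctx_create in place of the repository's getCudaContext()/getNppStreamContext() helpers; the function name makeHwDeviceCtx is illustrative only.

#include <string>

#include <torch/torch.h>

extern "C" {
#include <libavutil/hwcontext.h>
}

AVBufferRef* makeHwDeviceCtx(const torch::Device& device) {
  // Allocating any tensor on the target device forces PyTorch to create (or
  // attach to) the CUDA context first, so FFmpeg later reuses a context that
  // is compatible with PyTorch.
  torch::Tensor dummy = torch::empty(
      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));

  // Only then create FFmpeg's CUDA hw device context on the same device index.
  AVBufferRef* hwCtx = nullptr;
  std::string deviceIndex = std::to_string(static_cast<int>(device.index()));
  int err = av_hwdevice_ctx_create(
      &hwCtx, AV_HWDEVICE_TYPE_CUDA, deviceIndex.c_str(), nullptr, 0);
  TORCH_CHECK(err >= 0, "Failed to create FFmpeg CUDA hw device context");
  return hwCtx;
}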

src/torchcodec/_core/Frame.h

Lines changed: 2 additions & 2 deletions

@@ -14,12 +14,12 @@
 namespace facebook::torchcodec {

 struct FrameDims {
-  int width = 0;
   int height = 0;
+  int width = 0;

   FrameDims() = default;

-  FrameDims(int w, int h) : width(w), height(h) {}
+  FrameDims(int h, int w) : height(h), width(w) {}
 };

 // All public video decoding entry points return either a FrameOutput or a
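
Aside (not part of the commit): both FrameDims fields are ints, so the compiler cannot flag call sites that still pass (width, height); this commit updates each caller by hand. A small illustrative sketch of the new (height, width) order, assuming the struct above, when deriving dimensions from an HWC frame tensor:

#include <torch/torch.h>

// Illustrative only: build FrameDims from a frame tensor laid out as
// [height, width, channels], passing height first to match the new
// constructor order.
FrameDims dimsFromHWCFrame(const torch::Tensor& frame) {
  return FrameDims(
      static_cast<int>(frame.size(0)),   // height
      static_cast<int>(frame.size(1)));  // width
}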

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion

@@ -504,7 +504,7 @@ void SingleStreamDecoder::addVideoStream(
   }

   metadataDims_ =
-      FrameDims(streamMetadata.width.value(), streamMetadata.height.value());
+      FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
   for (auto& transform : transforms) {
     TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
     if (transform->getOutputFrameDims().has_value()) {

src/torchcodec/_core/Transform.cpp

Lines changed: 3 additions & 10 deletions

@@ -17,10 +17,6 @@ std::string toFilterGraphInterpolation(
   switch (mode) {
     case ResizeTransform::InterpolationMode::BILINEAR:
       return "bilinear";
-    case ResizeTransform::InterpolationMode::BICUBIC:
-      return "bicubic";
-    case ResizeTransform::InterpolationMode::NEAREST:
-      return "nearest";
     default:
       TORCH_CHECK(
           false,
@@ -33,10 +29,6 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
   switch (mode) {
     case ResizeTransform::InterpolationMode::BILINEAR:
       return SWS_BILINEAR;
-    case ResizeTransform::InterpolationMode::BICUBIC:
-      return SWS_BICUBIC;
-    case ResizeTransform::InterpolationMode::NEAREST:
-      return SWS_POINT;
     default:
       TORCH_CHECK(
           false,
@@ -48,12 +40,13 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
 } // namespace

 std::string ResizeTransform::getFilterGraphCpu() const {
-  return "scale=" + std::to_string(width_) + ":" + std::to_string(height_) +
+  return "scale=" + std::to_string(outputDims_.width) + ":" +
+      std::to_string(outputDims_.height) +
       ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_);
 }

 std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
-  return FrameDims(width_, height_);
+  return outputDims_;
 }

 bool ResizeTransform::isResize() const {
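
For concreteness, here is the filter description the rewritten getFilterGraphCpu() produces for an illustrative resize (the values are not from the commit):

// height=270, width=480, default bilinear interpolation.
ResizeTransform resize(FrameDims(270, 480));
// resize.getFilterGraphCpu() returns "scale=480:270:sws_flags=bilinear":
// the FFmpeg scale filter still takes width first, even though FrameDims
// now stores and accepts height first.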

src/torchcodec/_core/Transform.h

Lines changed: 6 additions & 9 deletions

@@ -37,15 +37,13 @@ class Transform {

 class ResizeTransform : public Transform {
  public:
-  enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST };
+  enum class InterpolationMode { BILINEAR };

-  ResizeTransform(int width, int height)
-      : width_(width),
-        height_(height),
-        interpolationMode_(InterpolationMode::BILINEAR) {}
+  ResizeTransform(const FrameDims& dims)
+      : outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {}

-  ResizeTransform(int width, int height, InterpolationMode interpolationMode)
-      : width_(width), height_(height), interpolationMode_(interpolationMode) {}
+  ResizeTransform(const FrameDims& dims, InterpolationMode interpolationMode)
+      : outputDims_(dims), interpolationMode_(interpolationMode) {}

   std::string getFilterGraphCpu() const override;
   std::optional<FrameDims> getOutputFrameDims() const override;
@@ -54,8 +52,7 @@ class ResizeTransform : public Transform {
   int getSwsFlags() const;

  private:
-  int width_;
-  int height_;
+  FrameDims outputDims_;
   InterpolationMode interpolationMode_;
 };
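
Aside (not part of the commit): a short usage sketch of the two FrameDims-based constructors, with illustrative dimensions. Note that BILINEAR is now the only InterpolationMode enumerator, so callers that previously requested BICUBIC or NEAREST no longer compile.

// Default interpolation (bilinear); dims are given as (height, width).
ResizeTransform byDefault(FrameDims(1080, 1920));

// Explicit interpolation mode; equivalent to the constructor above since
// BILINEAR is the only remaining mode.
ResizeTransform explicitMode(
    FrameDims(1080, 1920), ResizeTransform::InterpolationMode::BILINEAR);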

src/torchcodec/_core/custom_ops.cpp

Lines changed: 2 additions & 1 deletion

@@ -274,7 +274,8 @@ void _add_video_stream(
       "width and height must both be set or unset.");
   std::vector<Transform*> transforms;
   if (width.has_value()) {
-    transforms.push_back(new ResizeTransform(width.value(), height.value()));
+    transforms.push_back(
+        new ResizeTransform(FrameDims(height.value(), width.value())));
     width.reset();
     height.reset();
   }
