Skip to content

Commit fb06f87

Browse files
committed
Proper frame dims handling in CUDA
1 parent 23ec35f commit fb06f87

File tree

6 files changed

+33
-39
lines changed

6 files changed

+33
-39
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,6 @@ void CpuDeviceInterface::initialize(
5151
const VideoStreamOptions& videoStreamOptions,
5252
const std::vector<std::unique_ptr<Transform>>& transforms,
5353
const AVRational& timeBase,
54-
[[maybe_unused]] const FrameDims& metadataDims,
5554
const std::optional<FrameDims>& resizedOutputDims) {
5655
videoStreamOptions_ = videoStreamOptions;
5756
timeBase_ = timeBase;
@@ -106,7 +105,7 @@ void CpuDeviceInterface::initialize(
106105
}
107106

108107
ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
109-
const FrameDims& outputDims) {
108+
const FrameDims& outputDims) const {
110109
// swscale requires widths to be multiples of 32:
111110
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
112111
bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@ class CpuDeviceInterface : public DeviceInterface {
2828
const VideoStreamOptions& videoStreamOptions,
2929
const std::vector<std::unique_ptr<Transform>>& transforms,
3030
const AVRational& timeBase,
31-
[[maybe_unused]] const FrameDims& metadataDims,
3231
const std::optional<FrameDims>& resizedOutputDims) override;
3332

3433
void convertAVFrameToFrameOutput(
@@ -43,7 +42,7 @@ class CpuDeviceInterface : public DeviceInterface {
4342
torch::Tensor& outputTensor);
4443

4544
ColorConversionLibrary getColorConversionLibrary(
46-
const FrameDims& inputFrameDims);
45+
const FrameDims& inputFrameDims) const;
4746

4847
struct SwsFrameContext {
4948
int inputWidth = 0;

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 31 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -190,14 +190,12 @@ void CudaDeviceInterface::initialize(
190190
const VideoStreamOptions& videoStreamOptions,
191191
[[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
192192
const AVRational& timeBase,
193-
const FrameDims& metadataDims,
194193
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
195194
TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
196195
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
197196

198197
videoStreamOptions_ = videoStreamOptions;
199198
timeBase_ = timeBase;
200-
metadataDims_ = metadataDims;
201199

202200
// It is important for pytorch itself to create the cuda context. If ffmpeg
203201
// creates the context it may not be compatible with pytorch.
@@ -269,8 +267,8 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
269267
avFrame->height,
270268
frameFormat,
271269
avFrame->sample_aspect_ratio,
272-
metadataDims_.width,
273-
metadataDims_.height,
270+
avFrame->width,
271+
avFrame->height,
274272
outputFormat,
275273
filters.str(),
276274
timeBase_,
@@ -304,15 +302,19 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
304302
UniqueAVFrame& avFrame,
305303
FrameOutput& frameOutput,
306304
std::optional<torch::Tensor> preAllocatedOutputTensor) {
305+
// Note that CUDA does not yet support transforms, so the only possible
306+
// frame dimensions are the raw decoded frame's dimensions.
307+
auto frameDims = FrameDims(avFrame->width, avFrame->height);
308+
307309
if (preAllocatedOutputTensor.has_value()) {
308310
auto shape = preAllocatedOutputTensor.value().sizes();
309311
TORCH_CHECK(
310-
(shape.size() == 3) && (shape[0] == metadataDims_.height) &&
311-
(shape[1] == metadataDims_.width) && (shape[2] == 3),
312+
(shape.size() == 3) && (shape[0] == frameDims.height) &&
313+
(shape[1] == frameDims.width) && (shape[2] == 3),
312314
"Expected tensor of shape ",
313-
metadataDims_.height,
315+
frameDims.height,
314316
"x",
315-
metadataDims_.width,
317+
frameDims.width,
316318
"x3, got ",
317319
shape);
318320
}
@@ -333,34 +335,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
333335
// whatever reason. Typically that happens if the video's encoder isn't
334336
// supported by NVDEC.
335337
//
336-
// In both cases, we have a frame on the CPU, and we need a CPU device to
337-
// handle it. We send the frame back to the CUDA device when we're done.
338-
//
339-
// TODO: Perhaps we should cache cpuInterface?
340-
auto cpuInterface = createDeviceInterface(torch::kCPU);
341-
TORCH_CHECK(
342-
cpuInterface != nullptr, "Failed to create CPU device interface");
343-
cpuInterface->initialize(
344-
nullptr,
345-
VideoStreamOptions(),
346-
{},
347-
timeBase_,
348-
metadataDims_,
349-
std::nullopt);
338+
// In both cases, we have a frame on the CPU. We send the frame back to the
339+
// CUDA device when we're done.
350340

351341
enum AVPixelFormat frameFormat =
352342
static_cast<enum AVPixelFormat>(avFrame->format);
353343

354344
FrameOutput cpuFrameOutput;
355-
356-
if (frameFormat == AV_PIX_FMT_RGB24 &&
357-
avFrame->width == metadataDims_.width &&
358-
avFrame->height == metadataDims_.height) {
359-
// Reason 1 above. The frame is already in the format and dimensions that
360-
// we need, we just need to convert it to a tensor.
345+
if (frameFormat == AV_PIX_FMT_RGB24) {
346+
// Reason 1 above. The frame is already in RGB24, we just need to convert
347+
// it to a tensor.
361348
cpuFrameOutput.data = rgbAVFrameToTensor(avFrame);
362349
} else {
363-
// Reason 2 above. We need to do a full conversion.
350+
// Reason 2 above. We need to do a full conversion which requires an
351+
// actual CPU device.
352+
//
353+
// TODO: Perhaps we should cache cpuInterface?
354+
auto cpuInterface = createDeviceInterface(torch::kCPU);
355+
TORCH_CHECK(
356+
cpuInterface != nullptr, "Failed to create CPU device interface");
357+
cpuInterface->initialize(
358+
nullptr,
359+
VideoStreamOptions(),
360+
{},
361+
timeBase_,
362+
/*resizedOutputDims=*/std::nullopt);
363+
364364
cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
365365
}
366366

@@ -401,7 +401,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
401401
if (preAllocatedOutputTensor.has_value()) {
402402
dst = preAllocatedOutputTensor.value();
403403
} else {
404-
dst = allocateEmptyHWCTensor(metadataDims_, device_);
404+
dst = allocateEmptyHWCTensor(frameDims, device_);
405405
}
406406

407407
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
@@ -440,7 +440,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
440440
"cudaStreamGetFlags failed: ",
441441
cudaGetErrorString(err));
442442

443-
NppiSize oSizeROI = {metadataDims_.width, metadataDims_.height};
443+
NppiSize oSizeROI = {frameDims.width, frameDims.height};
444444
Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
445445

446446
NppStatus status;

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@ class CudaDeviceInterface : public DeviceInterface {
2626
[[maybe_unused]] const std::vector<std::unique_ptr<Transform>>&
2727
transforms,
2828
const AVRational& timeBase,
29-
const FrameDims& metadataOutputDims,
3029
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims)
3130
override;
3231

@@ -44,7 +43,6 @@ class CudaDeviceInterface : public DeviceInterface {
4443

4544
VideoStreamOptions videoStreamOptions_;
4645
AVRational timeBase_;
47-
FrameDims metadataDims_;
4846

4947
UniqueAVBufferRef ctx_;
5048
std::unique_ptr<NppStreamContext> nppCtx_;

src/torchcodec/_core/DeviceInterface.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ class DeviceInterface {
3737
const VideoStreamOptions& videoStreamOptions,
3838
const std::vector<std::unique_ptr<Transform>>& transforms,
3939
const AVRational& timeBase,
40-
const FrameDims& metadataDims,
4140
const std::optional<FrameDims>& resizedOutputDims) = 0;
4241

4342
virtual void convertAVFrameToFrameOutput(

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -521,7 +521,6 @@ void SingleStreamDecoder::addVideoStream(
521521
videoStreamOptions,
522522
transforms_,
523523
streamInfo.timeBase,
524-
metadataDims_,
525524
resizedOutputDims_);
526525
}
527526

0 commit comments

Comments
 (0)