Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/torchcodec/_core/BetaCudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ void BetaCudaDeviceInterface::initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx,
[[maybe_unused]] const SharedAVCodecContext& codecContext) {
STD_TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
rotation_ = rotationFromDegrees(getRotationFromStream(avStream));
if (!nvcuvidAvailable_ || !nativeNVDECSupport(device_, codecContext)) {
cpuFallback_ = createDeviceInterface(kStableCPU);
STD_TORCH_CHECK(
Expand All @@ -315,7 +317,6 @@ void BetaCudaDeviceInterface::initialize(
return;
}

STD_TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
timeBase_ = avStream->time_base;
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;

Expand Down Expand Up @@ -868,13 +869,58 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
gpuFrame->format == AV_PIX_FMT_CUDA,
"Expected CUDA format frame from BETA CUDA interface");

// When rotation is active, the pre-allocated tensor has post-rotation
// dimensions, but the NV12->RGB conversion needs pre-rotation dimensions.
// So we stash the pre-allocated tensor and pass std::nullopt to the
// converter, then copy the rotated result back into it afterwards.
// Our current rotation implementation is a bit hacky: once the beta CUDA
// interface supports native transforms, rotation should be handled as part
// of the transform pipeline instead.
std::optional<torch::Tensor> savedPreAllocatedOutputTensor = std::nullopt;
if (rotation_ != Rotation::NONE && preAllocatedOutputTensor.has_value()) {
savedPreAllocatedOutputTensor = preAllocatedOutputTensor;
preAllocatedOutputTensor = std::nullopt;
}

validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame);

at::cuda::CUDAStream nvdecStream =
at::cuda::getCurrentCUDAStream(device_.index());

frameOutput.data = convertNV12FrameToRGB(
gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);

if (rotation_ != Rotation::NONE) {
applyRotation(frameOutput, savedPreAllocatedOutputTensor);
}
}

// Rotates the decoded RGB frame according to the stream's rotation metadata.
// If the caller supplied a pre-allocated output tensor, the rotated frame is
// copied into it and that tensor becomes the output.
void BetaCudaDeviceInterface::applyRotation(
    FrameOutput& frameOutput,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  // Map the stream rotation onto the number of 90-degree counter-clockwise
  // turns expected by torch::rot90.
  const int numQuarterTurns = [this]() -> int {
    switch (rotation_) {
      case Rotation::CCW90:
        return 1;
      case Rotation::ROTATE180:
        return 2;
      case Rotation::CW90:
        return 3;
      default:
        STD_TORCH_CHECK(false, "Unexpected rotation value");
        return 0; // unreachable, keeps the compiler happy
    }
  }();

  // Rotate over the H and W dims of our HWC tensor. torch::rot90 returns a
  // view, so force a contiguous copy.
  frameOutput.data =
      torch::rot90(frameOutput.data, numQuarterTurns, {0, 1}).contiguous();

  // Honor the caller's pre-allocated tensor, if any: copy the rotated frame
  // into it and hand that tensor back.
  if (preAllocatedOutputTensor.has_value()) {
    preAllocatedOutputTensor->copy_(frameOutput.data);
    frameOutput.data = *preAllocatedOutputTensor;
  }
}

std::string BetaCudaDeviceInterface::getDetails() {
Expand Down
6 changes: 6 additions & 0 deletions src/torchcodec/_core/BetaCudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "NVDECCache.h"
#include "Transform.h"

#include <map>
#include <memory>
Expand Down Expand Up @@ -82,6 +83,10 @@ class BetaCudaDeviceInterface : public DeviceInterface {

UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame);

void applyRotation(
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor);

CUvideoparser videoParser_ = nullptr;
UniqueCUvideodecoder decoder_;
CUVIDEOFORMAT videoFormat_ = {};
Expand All @@ -102,6 +107,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
bool nvcuvidAvailable_ = false;
UniqueSwsContext swsContext_;
SwsFrameContext prevSwsFrameContext_;
Rotation rotation_ = Rotation::NONE;
};

} // namespace facebook::torchcodec
Expand Down
12 changes: 9 additions & 3 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,20 +1966,26 @@ def test_cpu_fallback_no_fallback_on_cpu_device(self):
assert "No fallback required" in str(decoder.cpu_fallback)

@pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"])
def test_rotation_applied_to_frames(self, dimension_order):
@pytest.mark.parametrize(
"device",
("cpu", pytest.param("cuda:beta", marks=pytest.mark.needs_cuda)),
)
def test_rotation_applied_to_frames(self, dimension_order, device):
"""Test that rotation is correctly applied to decoded frames.

Compares frames from NASA_VIDEO_ROTATED (which has 90-degree rotation
metadata) with manually rotated frames from NASA_VIDEO.
Tests all decoding methods to ensure rotation is applied consistently.
"""
decoder = VideoDecoder(
decoder, _ = make_video_decoder(
NASA_VIDEO.path,
device=device,
stream_index=NASA_VIDEO.default_stream_index,
dimension_order=dimension_order,
)
decoder_rotated = VideoDecoder(
decoder_rotated, _ = make_video_decoder(
NASA_VIDEO_ROTATED.path,
device=device,
stream_index=NASA_VIDEO_ROTATED.default_stream_index,
dimension_order=dimension_order,
)
Expand Down
Loading