Skip to content

Commit cc15044

Browse files
authored
Support rotation on beta cuda (#1235)
1 parent 2d1b5c6 commit cc15044

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ void BetaCudaDeviceInterface::initialize(
301301
const AVStream* avStream,
302302
const UniqueDecodingAVFormatContext& avFormatCtx,
303303
[[maybe_unused]] const SharedAVCodecContext& codecContext) {
304+
STD_TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
305+
rotation_ = rotationFromDegrees(getRotationFromStream(avStream));
304306
if (!nvcuvidAvailable_ || !nativeNVDECSupport(device_, codecContext)) {
305307
cpuFallback_ = createDeviceInterface(kStableCPU);
306308
STD_TORCH_CHECK(
@@ -314,7 +316,6 @@ void BetaCudaDeviceInterface::initialize(
314316
return;
315317
}
316318

317-
STD_TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
318319
timeBase_ = avStream->time_base;
319320
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
320321

@@ -867,12 +868,54 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
867868
gpuFrame->format == AV_PIX_FMT_CUDA,
868869
"Expected CUDA format frame from BETA CUDA interface");
869870

870-
validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame);
871-
872871
cudaStream_t nvdecStream = getCurrentCudaStream(device_.index());
873872

874-
frameOutput.data = convertNV12FrameToRGB(
875-
gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
873+
if (rotation_ == Rotation::NONE) {
874+
validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame);
875+
frameOutput.data = convertNV12FrameToRGB(
876+
gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
877+
} else {
878+
// preAllocatedOutputTensor has post-rotation dimensions, but NV12->RGB
879+
// conversion outputs pre-rotation dimensions, so we can't use it as the
880+
// conversion destination or validate it against the frame shape.
881+
// Once we support native transforms on the beta CUDA interface, rotation
882+
// should be handled as part of the transform pipeline instead.
883+
frameOutput.data = convertNV12FrameToRGB(
884+
gpuFrame,
885+
device_,
886+
nppCtx_,
887+
nvdecStream,
888+
/*preAllocatedOutputTensor=*/std::nullopt);
889+
applyRotation(frameOutput, preAllocatedOutputTensor);
890+
}
891+
}
892+
893+
void BetaCudaDeviceInterface::applyRotation(
894+
FrameOutput& frameOutput,
895+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
896+
int k = 0;
897+
switch (rotation_) {
898+
case Rotation::CCW90:
899+
k = 1;
900+
break;
901+
case Rotation::ROTATE180:
902+
k = 2;
903+
break;
904+
case Rotation::CW90:
905+
k = 3;
906+
break;
907+
default:
908+
STD_TORCH_CHECK(false, "Unexpected rotation value");
909+
break;
910+
}
911+
// Apply rotation using torch::rot90 on the H and W dims of our HWC tensor.
912+
// torch::rot90 returns a view, so we need to make it contiguous.
913+
frameOutput.data = torch::rot90(frameOutput.data, k, {0, 1}).contiguous();
914+
915+
if (preAllocatedOutputTensor.has_value()) {
916+
preAllocatedOutputTensor.value().copy_(frameOutput.data);
917+
frameOutput.data = preAllocatedOutputTensor.value();
918+
}
876919
}
877920

878921
std::string BetaCudaDeviceInterface::getDetails() {

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "DeviceInterface.h"
2121
#include "FFMPEGCommon.h"
2222
#include "NVDECCache.h"
23+
#include "Transform.h"
2324

2425
#include <map>
2526
#include <memory>
@@ -82,6 +83,10 @@ class BetaCudaDeviceInterface : public DeviceInterface {
8283

8384
UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame);
8485

86+
void applyRotation(
87+
FrameOutput& frameOutput,
88+
std::optional<torch::Tensor> preAllocatedOutputTensor);
89+
8590
CUvideoparser videoParser_ = nullptr;
8691
UniqueCUvideodecoder decoder_;
8792
CUVIDEOFORMAT videoFormat_ = {};
@@ -102,6 +107,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
102107
bool nvcuvidAvailable_ = false;
103108
UniqueSwsContext swsContext_;
104109
SwsFrameContext prevSwsFrameContext_;
110+
Rotation rotation_ = Rotation::NONE;
105111
};
106112

107113
} // namespace facebook::torchcodec

test/test_decoders.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,20 +1966,28 @@ def test_cpu_fallback_no_fallback_on_cpu_device(self):
19661966
assert "No fallback required" in str(decoder.cpu_fallback)
19671967

19681968
@pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"])
1969-
def test_rotation_applied_to_frames(self, dimension_order):
1969+
@pytest.mark.parametrize(
1970+
# We are skipping over cuda because we do not support rotation metadata
1971+
# for the FFmpeg CUDA interface.
1972+
"device",
1973+
("cpu", pytest.param("cuda:beta", marks=pytest.mark.needs_cuda)),
1974+
)
1975+
def test_rotation_applied_to_frames(self, dimension_order, device):
19701976
"""Test that rotation is correctly applied to decoded frames.
19711977
19721978
Compares frames from NASA_VIDEO_ROTATED (which has 90-degree rotation
19731979
metadata) with manually rotated frames from NASA_VIDEO.
19741980
Tests all decoding methods to ensure rotation is applied consistently.
19751981
"""
1976-
decoder = VideoDecoder(
1982+
decoder, _ = make_video_decoder(
19771983
NASA_VIDEO.path,
1984+
device=device,
19781985
stream_index=NASA_VIDEO.default_stream_index,
19791986
dimension_order=dimension_order,
19801987
)
1981-
decoder_rotated = VideoDecoder(
1988+
decoder_rotated, _ = make_video_decoder(
19821989
NASA_VIDEO_ROTATED.path,
1990+
device=device,
19831991
stream_index=NASA_VIDEO_ROTATED.default_stream_index,
19841992
dimension_order=dimension_order,
19851993
)

0 commit comments

Comments
 (0)