@@ -301,6 +301,8 @@ void BetaCudaDeviceInterface::initialize(
301301 const AVStream* avStream,
302302 const UniqueDecodingAVFormatContext& avFormatCtx,
303303 [[maybe_unused]] const SharedAVCodecContext& codecContext) {
304+ STD_TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
305+ rotation_ = rotationFromDegrees (getRotationFromStream (avStream));
304306 if (!nvcuvidAvailable_ || !nativeNVDECSupport (device_, codecContext)) {
305307 cpuFallback_ = createDeviceInterface (kStableCPU );
306308 STD_TORCH_CHECK (
@@ -314,7 +316,6 @@ void BetaCudaDeviceInterface::initialize(
314316 return ;
315317 }
316318
317- STD_TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
318319 timeBase_ = avStream->time_base ;
319320 frameRateAvgFromFFmpeg_ = avStream->r_frame_rate ;
320321
@@ -867,12 +868,54 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
867868 gpuFrame->format == AV_PIX_FMT_CUDA,
868869 " Expected CUDA format frame from BETA CUDA interface" );
869870
870- validatePreAllocatedTensorShape (preAllocatedOutputTensor, gpuFrame);
871-
872871 cudaStream_t nvdecStream = getCurrentCudaStream (device_.index ());
873872
874- frameOutput.data = convertNV12FrameToRGB (
875- gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
873+ if (rotation_ == Rotation::NONE) {
874+ validatePreAllocatedTensorShape (preAllocatedOutputTensor, gpuFrame);
875+ frameOutput.data = convertNV12FrameToRGB (
876+ gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
877+ } else {
878+ // preAllocatedOutputTensor has post-rotation dimensions, but NV12->RGB
879+ // conversion outputs pre-rotation dimensions, so we can't use it as the
880+ // conversion destination or validate it against the frame shape.
881+ // Once we support native transforms on the beta CUDA interface, rotation
882+ // should be handled as part of the transform pipeline instead.
883+ frameOutput.data = convertNV12FrameToRGB (
884+ gpuFrame,
885+ device_,
886+ nppCtx_,
887+ nvdecStream,
888+ /* preAllocatedOutputTensor=*/ std::nullopt );
889+ applyRotation (frameOutput, preAllocatedOutputTensor);
890+ }
891+ }
892+
893+ void BetaCudaDeviceInterface::applyRotation (
894+ FrameOutput& frameOutput,
895+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
896+ int k = 0 ;
897+ switch (rotation_) {
898+ case Rotation::CCW90:
899+ k = 1 ;
900+ break ;
901+ case Rotation::ROTATE180:
902+ k = 2 ;
903+ break ;
904+ case Rotation::CW90:
905+ k = 3 ;
906+ break ;
907+ default :
908+ STD_TORCH_CHECK (false , " Unexpected rotation value" );
909+ break ;
910+ }
911+ // Apply rotation using torch::rot90 on the H and W dims of our HWC tensor.
912+ // torch::rot90 returns a view, so we need to make it contiguous.
913+ frameOutput.data = torch::rot90 (frameOutput.data , k, {0 , 1 }).contiguous ();
914+
915+ if (preAllocatedOutputTensor.has_value ()) {
916+ preAllocatedOutputTensor.value ().copy_ (frameOutput.data );
917+ frameOutput.data = preAllocatedOutputTensor.value ();
918+ }
876919}
877920
878921std::string BetaCudaDeviceInterface::getDetails () {
0 commit comments