@@ -35,22 +35,19 @@ static bool g_cuda_beta = registerDeviceInterface(
3535
3636static int CUDAAPI
3737pfnSequenceCallback (void * pUserData, CUVIDEOFORMAT* videoFormat) {
38- BetaCudaDeviceInterface* decoder =
39- static_cast <BetaCudaDeviceInterface*>(pUserData);
38+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
4039 return decoder->streamPropertyChange (videoFormat);
4140}
4241
4342static int CUDAAPI
4443pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* picParams) {
45- BetaCudaDeviceInterface* decoder =
46- static_cast <BetaCudaDeviceInterface*>(pUserData);
44+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
4745 return decoder->frameReadyForDecoding (picParams);
4846}
4947
5048static int CUDAAPI
5149pfnDisplayPictureCallback (void * pUserData, CUVIDPARSERDISPINFO* dispInfo) {
52- BetaCudaDeviceInterface* decoder =
53- static_cast <BetaCudaDeviceInterface*>(pUserData);
50+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
5451 return decoder->frameReadyInDisplayOrder (dispInfo);
5552}
5653
@@ -112,27 +109,29 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
112109 caps.nMaxMBCount );
113110
114111 // Decoder creation parameters, taken from DALI
115- CUVIDDECODECREATEINFO decoder_info = {};
116- decoder_info.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8 ;
117- decoder_info.ChromaFormat = videoFormat->chroma_format ;
118- decoder_info.CodecType = videoFormat->codec ;
119- decoder_info.ulHeight = videoFormat->coded_height ;
120- decoder_info.ulWidth = videoFormat->coded_width ;
121- decoder_info.ulMaxHeight = videoFormat->coded_height ;
122- decoder_info.ulMaxWidth = videoFormat->coded_width ;
123- decoder_info.ulTargetHeight =
112+ CUVIDDECODECREATEINFO decoderParams = {};
113+ decoderParams.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8 ;
114+ decoderParams.ChromaFormat = videoFormat->chroma_format ;
115+ decoderParams.OutputFormat = cudaVideoSurfaceFormat_NV12;
116+ decoderParams.ulCreationFlags = cudaVideoCreate_Default;
117+ decoderParams.CodecType = videoFormat->codec ;
118+ decoderParams.ulHeight = videoFormat->coded_height ;
119+ decoderParams.ulWidth = videoFormat->coded_width ;
120+ decoderParams.ulMaxHeight = videoFormat->coded_height ;
121+ decoderParams.ulMaxWidth = videoFormat->coded_width ;
122+ decoderParams.ulTargetHeight =
124123 videoFormat->display_area .bottom - videoFormat->display_area .top ;
125- decoder_info .ulTargetWidth =
124+ decoderParams .ulTargetWidth =
126125 videoFormat->display_area .right - videoFormat->display_area .left ;
127- decoder_info .ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces ;
128- decoder_info .ulNumOutputSurfaces = 2 ;
129- decoder_info .display_area .left = videoFormat->display_area .left ;
130- decoder_info .display_area .right = videoFormat->display_area .right ;
131- decoder_info .display_area .top = videoFormat->display_area .top ;
132- decoder_info .display_area .bottom = videoFormat->display_area .bottom ;
126+ decoderParams .ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces ;
127+ decoderParams .ulNumOutputSurfaces = 2 ;
128+ decoderParams .display_area .left = videoFormat->display_area .left ;
129+ decoderParams .display_area .right = videoFormat->display_area .right ;
130+ decoderParams .display_area .top = videoFormat->display_area .top ;
131+ decoderParams .display_area .bottom = videoFormat->display_area .bottom ;
133132
134133 CUvideodecoder* decoder = new CUvideodecoder ();
135- result = cuvidCreateDecoder (decoder, &decoder_info );
134+ result = cuvidCreateDecoder (decoder, &decoderParams );
136135 TORCH_CHECK (
137136 result == CUDA_SUCCESS, " Failed to create NVDEC decoder: " , result);
138137 return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
@@ -260,12 +259,19 @@ void BetaCudaDeviceInterface::initializeBSF(
260259 getFFMPEGErrorStringFromErrorCode (retVal));
261260}
262261
263- void BetaCudaDeviceInterface::initializeInterface (
262+ void BetaCudaDeviceInterface::initialize (
264263 const AVStream* avStream,
265264 const UniqueDecodingAVFormatContext& avFormatCtx) {
266265 torch::Tensor dummyTensorForCudaInitialization = torch::empty (
267266 {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
268267
268+ auto cudaDevice = torch::Device (torch::kCUDA );
269+ defaultCudaInterface_ =
270+ std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
271+ AVCodecContext dummyCodecContext = {};
272+ defaultCudaInterface_->initialize (avStream, avFormatCtx);
273+ defaultCudaInterface_->registerHardwareDeviceWithCodec (&dummyCodecContext);
274+
269275 TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
270276 timeBase_ = avStream->time_base ;
271277
@@ -422,6 +428,10 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
422428 CUVIDPARSERDISPINFO dispInfo = readyFrames_.front ();
423429 readyFrames_.pop ();
424430
431+ // TODONVDEC P1 we need to set the procParams.output_stream field to the
432+ // current CUDA stream and ensure proper synchronization. There's a related
433+ // NVDECTODO in CudaDeviceInterface.cpp where we do the necessary
434+ // synchronization for NPP.
425435 CUVIDPROCPARAMS procParams = {};
426436 procParams.progressive_frame = dispInfo.progressive_frame ;
427437 procParams.top_field_first = dispInfo.top_field_first ;
@@ -555,8 +565,6 @@ void BetaCudaDeviceInterface::flush() {
555565}
556566
557567void BetaCudaDeviceInterface::convertAVFrameToFrameOutput (
558- const VideoStreamOptions& videoStreamOptions,
559- const AVRational& timeBase,
560568 UniqueAVFrame& avFrame,
561569 FrameOutput& frameOutput,
562570 std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -567,20 +575,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
567575 // TODONVDEC P1: we use the 'default' cuda device interface for color
568576 // conversion. That's a temporary hack to make things work. we should abstract
569577 // the color conversion stuff separately.
570- if (!defaultCudaInterface_) {
571- auto cudaDevice = torch::Device (torch::kCUDA );
572- defaultCudaInterface_ =
573- std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
574- AVCodecContext dummyCodecContext = {};
575- defaultCudaInterface_->initializeContext (&dummyCodecContext);
576- }
577-
578578 defaultCudaInterface_->convertAVFrameToFrameOutput (
579- videoStreamOptions,
580- timeBase,
581- avFrame,
582- frameOutput,
583- preAllocatedOutputTensor);
579+ avFrame, frameOutput, preAllocatedOutputTensor);
584580}
585581
586582} // namespace facebook::torchcodec
0 commit comments