Skip to content

Commit 8e73bcf

Browse files
committed
Add TODOs and more explicit initialization
1 parent 70873bf commit 8e73bcf

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,27 +112,29 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
112112
caps.nMaxMBCount);
113113

114114
// Decoder creation parameters, taken from DALI
115-
CUVIDDECODECREATEINFO decoder_info = {};
116-
decoder_info.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8;
117-
decoder_info.ChromaFormat = videoFormat->chroma_format;
118-
decoder_info.CodecType = videoFormat->codec;
119-
decoder_info.ulHeight = videoFormat->coded_height;
120-
decoder_info.ulWidth = videoFormat->coded_width;
121-
decoder_info.ulMaxHeight = videoFormat->coded_height;
122-
decoder_info.ulMaxWidth = videoFormat->coded_width;
123-
decoder_info.ulTargetHeight =
115+
CUVIDDECODECREATEINFO decoderParams = {};
116+
decoderParams.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8;
117+
decoderParams.ChromaFormat = videoFormat->chroma_format;
118+
decoderParams.OutputFormat = cudaVideoSurfaceFormat_NV12;
119+
decoderParams.ulCreationFlags = cudaVideoCreate_Default;
120+
decoderParams.CodecType = videoFormat->codec;
121+
decoderParams.ulHeight = videoFormat->coded_height;
122+
decoderParams.ulWidth = videoFormat->coded_width;
123+
decoderParams.ulMaxHeight = videoFormat->coded_height;
124+
decoderParams.ulMaxWidth = videoFormat->coded_width;
125+
decoderParams.ulTargetHeight =
124126
videoFormat->display_area.bottom - videoFormat->display_area.top;
125-
decoder_info.ulTargetWidth =
127+
decoderParams.ulTargetWidth =
126128
videoFormat->display_area.right - videoFormat->display_area.left;
127-
decoder_info.ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces;
128-
decoder_info.ulNumOutputSurfaces = 2;
129-
decoder_info.display_area.left = videoFormat->display_area.left;
130-
decoder_info.display_area.right = videoFormat->display_area.right;
131-
decoder_info.display_area.top = videoFormat->display_area.top;
132-
decoder_info.display_area.bottom = videoFormat->display_area.bottom;
129+
decoderParams.ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces;
130+
decoderParams.ulNumOutputSurfaces = 2;
131+
decoderParams.display_area.left = videoFormat->display_area.left;
132+
decoderParams.display_area.right = videoFormat->display_area.right;
133+
decoderParams.display_area.top = videoFormat->display_area.top;
134+
decoderParams.display_area.bottom = videoFormat->display_area.bottom;
133135

134136
CUvideodecoder* decoder = new CUvideodecoder();
135-
result = cuvidCreateDecoder(decoder, &decoder_info);
137+
result = cuvidCreateDecoder(decoder, &decoderParams);
136138
TORCH_CHECK(
137139
result == CUDA_SUCCESS, "Failed to create NVDEC decoder: ", result);
138140
return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{});
@@ -356,6 +358,10 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
356358
CUVIDPARSERDISPINFO dispInfo = readyFrames_.front();
357359
readyFrames_.pop();
358360

361+
// TODONVDEC P1 we need to set the procParams.output_stream field to the
362+
// current CUDA stream and ensure proper synchronization. There's a related
363+
// NVDECTODO in CudaDeviceInterface.cpp where we do the necessary
364+
// synchronization for NPP.
359365
CUVIDPROCPARAMS procParams = {};
360366
procParams.progressive_frame = dispInfo.progressive_frame;
361367
procParams.top_field_first = dispInfo.top_field_first;

0 commit comments

Comments
 (0)