@@ -13,11 +13,21 @@ extern "C" {
1313#include < libavutil/pixdesc.h>
1414}
1515
16+ // TODONVDEC P1 Changes were made to this file to accommodate the BETA CUDA
17+ // interface (see other TODONVDEC below). That's because the BETA CUDA interface
18+ // relies on this default CUDA interface to do the color conversion. That's
19+ // hacky, ugly, and leads to complicated code. We should refactor all this so
20+ // that an interface doesn't need to know anything about any other interface.
21+ // Note - this is more than just about the BETA CUDA interface: this default
22+ // interface already relies on the CPU interface to do software decoding when
23+ // needed, and that's already leading to similar complications.
24+
1625namespace facebook ::torchcodec {
1726namespace {
1827
19- static bool g_cuda =
20- registerDeviceInterface (torch::kCUDA , [](const torch::Device& device) {
28+ static bool g_cuda = registerDeviceInterface( // file-local flag; its initializer performs the registration as a side effect
29+ DeviceInterfaceKey (torch::kCUDA ), // key under which the factory below is registered
30+ [](const torch::Device& device) { // factory: builds a CudaDeviceInterface for the given device
2131 return new CudaDeviceInterface (device); // NOTE(review): raw pointer returned; presumably the caller takes ownership - confirm
2232 });
2333
@@ -193,13 +203,18 @@ CudaDeviceInterface::~CudaDeviceInterface() {
193203 }
194204}
195205
196- void CudaDeviceInterface::initialize (
197- AVCodecContext* codecContext,
198- const AVRational& timeBase) {
199- TORCH_CHECK (ctx_, " FFmpeg HW device has not been initialized" );
200- TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
201- codecContext->hw_device_ctx = av_buffer_ref (ctx_.get ());
202- timeBase_ = timeBase;
206+ void CudaDeviceInterface::initialize (const AVStream* avStream) { // Caches the stream's time base and sets up the CPU interface this class delegates to.
207+ TORCH_CHECK (avStream != nullptr , " avStream is null" ); // checked before avStream is dereferenced on the next line
208+ timeBase_ = avStream->time_base ;
209+
210+ cpuInterface_ = createDeviceInterface (torch::kCPU ); // built once here and stored, instead of being re-created on each frame conversion
211+ TORCH_CHECK (
212+ cpuInterface_ != nullptr , " Failed to create CPU device interface" );
213+ cpuInterface_->initialize (avStream); // the CPU interface is initialized from the same stream
214+ cpuInterface_->initializeVideo (
215+ VideoStreamOptions (), // default-constructed options
216+ {},
217+ /* resizedOutputDims=*/ std::nullopt ); // no resized output dims are passed to the CPU interface
203218}
204219
205220void CudaDeviceInterface::initializeVideo (
@@ -209,6 +224,13 @@ void CudaDeviceInterface::initializeVideo(
209224 videoStreamOptions_ = videoStreamOptions;
210225}
211226
227+ void CudaDeviceInterface::registerHardwareDeviceWithCodec (
228+ AVCodecContext* codecContext) { // Stores a reference to this interface's FFmpeg HW device context (ctx_) on the given codec context.
229+ TORCH_CHECK (ctx_, " FFmpeg HW device has not been initialized" ); // ctx_ must be set before it is referenced below
230+ TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
231+ codecContext->hw_device_ctx = av_buffer_ref (ctx_.get ()); // codec gets the result of av_buffer_ref on ctx_, not ctx_ itself
232+ }
233+
212234UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24 (
213235 UniqueAVFrame& avFrame) {
214236 // We need FFmpeg filters to handle those conversion cases which are not
@@ -222,6 +244,12 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
222244 return std::move (avFrame);
223245 }
224246
247+ if (avFrame->hw_frames_ctx == nullptr ) {
248+ // TODONVDEC P2 return early for the beta interface, where avFrames don't
249+ // have a hw_frames_ctx. We should get rid of this or improve the logic.
250+ return std::move (avFrame);
251+ }
252+
225253 auto hwFramesCtx =
226254 reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
227255 TORCH_CHECK (
@@ -351,19 +379,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
351379 } else {
352380 // Reason 2 above. We need to do a full conversion which requires an
353381 // actual CPU device.
354- //
355- // TODO: Perhaps we should cache cpuInterface?
356- auto cpuInterface = createDeviceInterface (torch::kCPU );
357- TORCH_CHECK (
358- cpuInterface != nullptr , " Failed to create CPU device interface" );
359- cpuInterface->initialize (
360- /* codecContext=*/ nullptr , timeBase_);
361- cpuInterface->initializeVideo (
362- VideoStreamOptions (),
363- {},
364- /* resizedOutputDims=*/ std::nullopt );
365-
366- cpuInterface->convertAVFrameToFrameOutput (avFrame, cpuFrameOutput);
382+ cpuInterface_->convertAVFrameToFrameOutput (avFrame, cpuFrameOutput);
367383 }
368384
369385 // Finally, we need to send the frame back to the GPU. Note that the
@@ -383,22 +399,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
383399 // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
384400 // because this is what the NPP color conversion routines expect. This SHOULD
385401 // be enforced by our call to maybeConvertAVFrameToNV12OrRGB24() above.
386- auto hwFramesCtx =
387- reinterpret_cast <AVHWFramesContext*>(avFrame-> hw_frames_ctx -> data );
388- TORCH_CHECK (
389- hwFramesCtx ! = nullptr ,
390- " The AVFrame does not have a hw_frames_ctx. "
391- " That's unexpected, please report this to the TorchCodec repo. " );
392-
393- AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
402+ // TODONVDEC P2 this can be hit from the beta interface, but there's no
403+ // hw_frames_ctx in this case. We should try to understand how that affects
404+ // this validation.
405+ AVHWFramesContext* hwFramesCtx = nullptr ;
406+ if (avFrame-> hw_frames_ctx != nullptr ) {
407+ hwFramesCtx =
408+ reinterpret_cast <AVHWFramesContext*>(avFrame-> hw_frames_ctx -> data );
409+ AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
394410
395- TORCH_CHECK (
396- actualFormat == AV_PIX_FMT_NV12,
397- " The AVFrame is " ,
398- (av_get_pix_fmt_name (actualFormat) ? av_get_pix_fmt_name (actualFormat)
399- : " unknown" ),
400- " , but we expected AV_PIX_FMT_NV12. "
401- " That's unexpected, please report this to the TorchCodec repo." );
411+ TORCH_CHECK (
412+ actualFormat == AV_PIX_FMT_NV12,
413+ " The AVFrame is " ,
414+ (av_get_pix_fmt_name (actualFormat) ? av_get_pix_fmt_name (actualFormat)
415+ : " unknown" ),
416+ " , but we expected AV_PIX_FMT_NV12. "
417+ " That's unexpected, please report this to the TorchCodec repo." );
418+ }
402419
403420 torch::Tensor& dst = frameOutput.data ;
404421 if (preAllocatedOutputTensor.has_value ()) {
@@ -418,21 +435,24 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
418435 // arbitrary, but unfortunately we know it's hardcoded to be the default
419436 // stream by FFmpeg:
420437 // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
421- TORCH_CHECK (
422- hwFramesCtx->device_ctx != nullptr ,
423- " The AVFrame's hw_frames_ctx does not have a device_ctx. " );
424- auto cudaDeviceCtx =
425- static_cast <AVCUDADeviceContext*>(hwFramesCtx->device_ctx ->hwctx );
426- TORCH_CHECK (cudaDeviceCtx != nullptr , " The hardware context is null" );
427-
428- at::cuda::CUDAEvent nvdecDoneEvent;
429- at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
430- c10::cuda::getStreamFromExternal (cudaDeviceCtx->stream , deviceIndex);
431- nvdecDoneEvent.record (nvdecStream);
432-
433- // Don't start NPP work before NVDEC is done decoding the frame!
434438 at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream (deviceIndex);
435- nvdecDoneEvent.block (nppStream);
439+ if (hwFramesCtx) {
440+ // TODONVDEC P2 this block won't be hit from the beta interface because
441+ // there is no hwFramesCtx, but we should still make sure there's no CUDA
442+ // stream sync issue in the beta interface.
443+ TORCH_CHECK (
444+ hwFramesCtx->device_ctx != nullptr ,
445+ " The AVFrame's hw_frames_ctx does not have a device_ctx. " );
446+ auto cudaDeviceCtx =
447+ static_cast <AVCUDADeviceContext*>(hwFramesCtx->device_ctx ->hwctx );
448+ TORCH_CHECK (cudaDeviceCtx != nullptr , " The hardware context is null" );
449+ at::cuda::CUDAEvent nvdecDoneEvent;
450+ at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
451+ c10::cuda::getStreamFromExternal (cudaDeviceCtx->stream , deviceIndex);
452+ nvdecDoneEvent.record (nvdecStream);
453+ // Don't start NPP work before NVDEC is done decoding the frame!
454+ nvdecDoneEvent.block (nppStream);
455+ }
436456
437457 // Create the NPP context if we haven't yet.
438458 nppCtx_->hStream = nppStream.stream ();
0 commit comments