Skip to content

Commit ce5667d

Browse files
authored
BETA CUDA interface: properly set NVDEC stream and wait on it (#930)
1 parent 8f75c54 commit ce5667d

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// This source code is licensed under the BSD-style license found in the
55
// LICENSE file in the root directory of this source tree.
66

7+
#include <c10/cuda/CUDAStream.h>
78
#include <torch/types.h>
89
#include <mutex>
910
#include <vector>
@@ -479,14 +480,18 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
479480
CUVIDPARSERDISPINFO dispInfo = readyFrames_.front();
480481
readyFrames_.pop();
481482

482-
// TODONVDEC P1 we need to set the procParams.output_stream field to the
483-
// current CUDA stream and ensure proper synchronization. There's a related
484-
// NVDECTODO in CudaDeviceInterface.cpp where we do the necessary
485-
// synchronization for NPP.
486483
CUVIDPROCPARAMS procParams = {};
487484
procParams.progressive_frame = dispInfo.progressive_frame;
488485
procParams.top_field_first = dispInfo.top_field_first;
489486
procParams.unpaired_field = dispInfo.repeat_first_field < 0;
487+
// We set the NVDEC stream to the current stream. It will be waited upon by
488+
// the NPP stream before any color conversion. Currently, that syncing logic
489+
// is in the default interface.
490+
// About the cast: PyTorch gives us a cudaStream_t, which is interchangeable with
491+
// CUstream
492+
procParams.output_stream = reinterpret_cast<CUstream>(
493+
at::cuda::getCurrentCUDAStream(device_.index()).stream());
494+
490495
CUdeviceptr framePtr = 0;
491496
unsigned int pitch = 0;
492497

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -426,35 +426,36 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
426426
dst = allocateEmptyHWCTensor(frameDims, device_);
427427
}
428428

429-
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
430-
431-
// Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
432-
// NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
433-
// We will be waiting for this event to complete before calling the NPP
434-
// functions, to ensure NVDEC has finished decoding the frame before running
435-
// the NPP color-conversion.
436-
// Note that our code is generic and assumes that the NVDEC's stream can be
437-
// arbitrary, but unfortunately we know it's hardcoded to be the default
438-
// stream by FFmpeg:
429+
// We need to make sure NVDEC has finished decoding a frame before
430+
// color-converting it with NPP.
431+
// So we make the NPP stream wait for NVDEC to finish.
432+
// If we're in the default CUDA interface, we figure out the NVDEC stream from
433+
// the avFrame's hardware context. But in reality, we know that this stream is
434+
// hardcoded to be the default stream by FFmpeg:
439435
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
440-
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
436+
// If we're in the BETA CUDA interface, we know the NVDEC stream was set with
437+
// getCurrentCUDAStream(), so it's the same as the nppStream.
438+
at::cuda::CUDAStream nppStream =
439+
at::cuda::getCurrentCUDAStream(device_.index());
440+
// We can't create a CUDAStream without assigning it a value, so we initialize
441+
// it to the nppStream, which is valid for the BETA interface.
442+
at::cuda::CUDAStream nvdecStream = nppStream;
441443
if (hwFramesCtx) {
442-
// TODONVDEC P2 this block won't be hit from the beta interface because
443-
// there is no hwFramesCtx, but we should still make sure there's no CUDA
444-
// stream sync issue in the beta interface.
444+
// Default interface path
445445
TORCH_CHECK(
446446
hwFramesCtx->device_ctx != nullptr,
447447
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
448448
auto cudaDeviceCtx =
449449
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
450450
TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null");
451-
at::cuda::CUDAEvent nvdecDoneEvent;
452-
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
453-
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
454-
nvdecDoneEvent.record(nvdecStream);
455-
// Don't start NPP work before NVDEC is done decoding the frame!
456-
nvdecDoneEvent.block(nppStream);
451+
nvdecStream = // That's always the default stream. Sad.
452+
c10::cuda::getStreamFromExternal(
453+
cudaDeviceCtx->stream, device_.index());
457454
}
455+
// Don't start NPP work before NVDEC is done decoding the frame!
456+
at::cuda::CUDAEvent nvdecDoneEvent;
457+
nvdecDoneEvent.record(nvdecStream);
458+
nvdecDoneEvent.block(nppStream);
458459

459460
// Create the NPP context if we haven't yet.
460461
nppCtx_->hStream = nppStream.stream();

0 commit comments

Comments
 (0)