Skip to content

Commit 1bd7d9c

Browse files
committed
properly set NVDEC stream and wait on it
1 parent b60d50f commit 1bd7d9c

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// This source code is licensed under the BSD-style license found in the
55
// LICENSE file in the root directory of this source tree.
66

7+
#include <c10/cuda/CUDAStream.h>
78
#include <torch/types.h>
89
#include <mutex>
910
#include <vector>
@@ -459,14 +460,18 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
459460
CUVIDPARSERDISPINFO dispInfo = readyFrames_.front();
460461
readyFrames_.pop();
461462

462-
// TODONVDEC P1 we need to set the procParams.output_stream field to the
463-
// current CUDA stream and ensure proper synchronization. There's a related
464-
// NVDECTODO in CudaDeviceInterface.cpp where we do the necessary
465-
// synchronization for NPP.
466463
CUVIDPROCPARAMS procParams = {};
467464
procParams.progressive_frame = dispInfo.progressive_frame;
468465
procParams.top_field_first = dispInfo.top_field_first;
469466
procParams.unpaired_field = dispInfo.repeat_first_field < 0;
467+
// We set the NVDEC stream to the current stream. It will be waited upon by
468+
// the NPP stream before any color conversion. Currently, that syncing logic
469+
// is in the default interface.
470+
// Re types: we get a cudaStream_t from PyTorch but it's interchangeable with
471+
// CUstream
472+
procParams.output_stream = reinterpret_cast<CUstream>(
473+
at::cuda::getCurrentCUDAStream(device_.index()).stream());
474+
470475
CUdeviceptr framePtr = 0;
471476
unsigned int pitch = 0;
472477

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -426,35 +426,36 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
426426
dst = allocateEmptyHWCTensor(frameDims, device_);
427427
}
428428

429-
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
430-
431-
// Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
432-
// NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
433-
// We will be waiting for this event to complete before calling the NPP
434-
// functions, to ensure NVDEC has finished decoding the frame before running
435-
// the NPP color-conversion.
436-
// Note that our code is generic and assumes that the NVDEC's stream can be
437-
// arbitrary, but unfortunately we know it's hardcoded to be the default
438-
// stream by FFmpeg:
429+
// We need to make sure NVDEC has finished decoding a frame before
430+
// color-converting it with NPP.
431+
// So we make the NPP stream wait for NVDEC to finish.
432+
// If we're in the default CUDA interface, we figure out the NVDEC stream from
433+
// the avFrame's hardware context. But in reality, we know that this stream is
434+
// hardcoded to be the default stream by FFmpeg:
439435
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
440-
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
436+
// If we're in the BETA CUDA interface, we know the NVDEC stream was set with
437+
// getCurrentCUDAStream(), so it's the same as the nppStream.
438+
at::cuda::CUDAStream nppStream =
439+
at::cuda::getCurrentCUDAStream(device_.index());
440+
// We can't create a CUDAStream without assigning it a value so we initialize
441+
// it to the nppStream, which is valid for the BETA interface.
442+
at::cuda::CUDAStream nvdecStream = nppStream;
441443
if (hwFramesCtx) {
442-
// TODONVDEC P2 this block won't be hit from the beta interface because
443-
// there is no hwFramesCtx, but we should still make sure there's no CUDA
444-
// stream sync issue in the beta interface.
444+
// Default interface path
445445
TORCH_CHECK(
446446
hwFramesCtx->device_ctx != nullptr,
447447
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
448448
auto cudaDeviceCtx =
449449
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
450450
TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null");
451-
at::cuda::CUDAEvent nvdecDoneEvent;
452-
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
453-
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
454-
nvdecDoneEvent.record(nvdecStream);
455-
// Don't start NPP work before NVDEC is done decoding the frame!
456-
nvdecDoneEvent.block(nppStream);
451+
nvdecStream = // That's always the default stream. Sad.
452+
c10::cuda::getStreamFromExternal(
453+
cudaDeviceCtx->stream, device_.index());
457454
}
455+
// Don't start NPP work before NVDEC is done decoding the frame!
456+
at::cuda::CUDAEvent nvdecDoneEvent;
457+
nvdecDoneEvent.record(nvdecStream);
458+
nvdecDoneEvent.block(nppStream);
458459

459460
// Create the NPP context if we haven't yet.
460461
nppCtx_->hStream = nppStream.stream();

0 commit comments

Comments
 (0)