Skip to content

Commit 941d6a3

Browse files
committed
Fix CUDA?
1 parent 404b2e4 commit 941d6a3

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@ void convertAVFrameToDecodedOutputOnCuda(
189189
VideoDecoder::RawDecodedOutput& rawOutput,
190190
VideoDecoder::DecodedOutput& output,
191191
std::optional<torch::Tensor> preAllocatedOutputTensor) {
192-
AVFrame* src = rawOutput.frame.get();
192+
AVFrame* avFrame = rawOutput.avFrame.get();
193193

194194
TORCH_CHECK(
195-
src->format == AV_PIX_FMT_CUDA,
195+
avFrame->format == AV_PIX_FMT_CUDA,
196196
"Expected format to be AV_PIX_FMT_CUDA, got " +
197-
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
197+
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
198198
auto frameDims =
199-
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *src);
199+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
200200
int height = frameDims.height;
201201
int width = frameDims.width;
202202
torch::Tensor& dst = output.frame;
@@ -220,21 +220,21 @@ void convertAVFrameToDecodedOutputOnCuda(
220220
c10::cuda::CUDAGuard deviceGuard(device);
221221

222222
NppiSize oSizeROI = {width, height};
223-
Npp8u* input[2] = {src->data[0], src->data[1]};
223+
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
224224

225225
auto start = std::chrono::high_resolution_clock::now();
226226
NppStatus status;
227-
if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
227+
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
228228
status = nppiNV12ToRGB_709CSC_8u_P2C3R(
229229
input,
230-
src->linesize[0],
230+
avFrame->linesize[0],
231231
static_cast<Npp8u*>(dst.data_ptr()),
232232
dst.stride(0),
233233
oSizeROI);
234234
} else {
235235
status = nppiNV12ToRGB_8u_P2C3R(
236236
input,
237-
src->linesize[0],
237+
avFrame->linesize[0],
238238
static_cast<Npp8u*>(dst.data_ptr()),
239239
dst.stride(0),
240240
oSizeROI);

0 commit comments

Comments
 (0)