@@ -190,14 +190,12 @@ void CudaDeviceInterface::initialize(
190190 const VideoStreamOptions& videoStreamOptions,
191191 [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
192192 const AVRational& timeBase,
193- const FrameDims& metadataDims,
194193 [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
195194 TORCH_CHECK (!ctx_, " FFmpeg HW device context already initialized" );
196195 TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
197196
198197 videoStreamOptions_ = videoStreamOptions;
199198 timeBase_ = timeBase;
200- metadataDims_ = metadataDims;
201199
202200 // It is important for pytorch itself to create the cuda context. If ffmpeg
203201 // creates the context it may not be compatible with pytorch.
@@ -269,8 +267,8 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
269267 avFrame->height ,
270268 frameFormat,
271269 avFrame->sample_aspect_ratio ,
272- metadataDims_. width ,
273- metadataDims_. height ,
270+ avFrame-> width ,
271+ avFrame-> height ,
274272 outputFormat,
275273 filters.str (),
276274 timeBase_,
@@ -304,15 +302,19 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
304302 UniqueAVFrame& avFrame,
305303 FrameOutput& frameOutput,
306304 std::optional<torch::Tensor> preAllocatedOutputTensor) {
305+ // Note that CUDA does not yet support transforms, so the only possible
306+ // frame dimensions are the raw decoded frame's dimensions.
307+ auto frameDims = FrameDims (avFrame->width , avFrame->height );
308+
307309 if (preAllocatedOutputTensor.has_value ()) {
308310 auto shape = preAllocatedOutputTensor.value ().sizes ();
309311 TORCH_CHECK (
310- (shape.size () == 3 ) && (shape[0 ] == metadataDims_ .height ) &&
311- (shape[1 ] == metadataDims_ .width ) && (shape[2 ] == 3 ),
312+ (shape.size () == 3 ) && (shape[0 ] == frameDims .height ) &&
313+ (shape[1 ] == frameDims .width ) && (shape[2 ] == 3 ),
312314 " Expected tensor of shape " ,
313- metadataDims_ .height ,
315+ frameDims .height ,
314316 " x" ,
315- metadataDims_ .width ,
317+ frameDims .width ,
316318 " x3, got " ,
317319 shape);
318320 }
@@ -333,34 +335,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
333335 // whatever reason. Typically that happens if the video's encoder isn't
334336 // supported by NVDEC.
335337 //
336- // In both cases, we have a frame on the CPU, and we need a CPU device to
337- // handle it. We send the frame back to the CUDA device when we're done.
338- //
339- // TODO: Perhaps we should cache cpuInterface?
340- auto cpuInterface = createDeviceInterface (torch::kCPU );
341- TORCH_CHECK (
342- cpuInterface != nullptr , " Failed to create CPU device interface" );
343- cpuInterface->initialize (
344- nullptr ,
345- VideoStreamOptions (),
346- {},
347- timeBase_,
348- metadataDims_,
349- std::nullopt );
338+ // In both cases, we have a frame on the CPU. We send the frame back to the
339+ // CUDA device when we're done.
350340
351341 enum AVPixelFormat frameFormat =
352342 static_cast <enum AVPixelFormat>(avFrame->format );
353343
354344 FrameOutput cpuFrameOutput;
355-
356- if (frameFormat == AV_PIX_FMT_RGB24 &&
357- avFrame->width == metadataDims_.width &&
358- avFrame->height == metadataDims_.height ) {
359- // Reason 1 above. The frame is already in the format and dimensions that
360- // we need, we just need to convert it to a tensor.
345+ if (frameFormat == AV_PIX_FMT_RGB24) {
346+ // Reason 1 above. The frame is already in RGB24, we just need to convert
347+ // it to a tensor.
361348 cpuFrameOutput.data = rgbAVFrameToTensor (avFrame);
362349 } else {
363- // Reason 2 above. We need to do a full conversion.
350+ // Reason 2 above. We need to do a full conversion which requires an
351+ // actual CPU device.
352+ //
353+ // TODO: Perhaps we should cache cpuInterface?
354+ auto cpuInterface = createDeviceInterface (torch::kCPU );
355+ TORCH_CHECK (
356+ cpuInterface != nullptr , " Failed to create CPU device interface" );
357+ cpuInterface->initialize (
358+ nullptr ,
359+ VideoStreamOptions (),
360+ {},
361+ timeBase_,
362+ /* resizedOutputDims=*/ std::nullopt );
363+
364364 cpuInterface->convertAVFrameToFrameOutput (avFrame, cpuFrameOutput);
365365 }
366366
@@ -401,7 +401,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
401401 if (preAllocatedOutputTensor.has_value ()) {
402402 dst = preAllocatedOutputTensor.value ();
403403 } else {
404- dst = allocateEmptyHWCTensor (metadataDims_ , device_);
404+ dst = allocateEmptyHWCTensor (frameDims , device_);
405405 }
406406
407407 torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device_);
@@ -440,7 +440,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
440440 " cudaStreamGetFlags failed: " ,
441441 cudaGetErrorString (err));
442442
443- NppiSize oSizeROI = {metadataDims_ .width , metadataDims_ .height };
443+ NppiSize oSizeROI = {frameDims .width , frameDims .height };
444444 Npp8u* yuvData[2 ] = {avFrame->data [0 ], avFrame->data [1 ]};
445445
446446 NppStatus status;
0 commit comments