@@ -196,10 +196,48 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
196
196
UniqueAVFrame& avFrame,
197
197
FrameOutput& frameOutput,
198
198
std::optional<torch::Tensor> preAllocatedOutputTensor) {
199
+ // We check that avFrame->format == AV_PIX_FMT_CUDA. This only ensures the
200
+ // AVFrame is on GPU memory. It can be on CPU memory if the video isn't
201
+ // supported by NVDEC for whatever reason: NVDEC falls back to CPU decoding in
202
+ // this case, and our check fails.
203
+ // TODO: we could send the frame back into the CPU path, and rely on
204
+ // swscale/filtergraph to run the color conversion to properly output the
205
+ // frame.
199
206
TORCH_CHECK (
200
207
avFrame->format == AV_PIX_FMT_CUDA,
201
- " Expected format to be AV_PIX_FMT_CUDA, got " +
202
- std::string (av_get_pix_fmt_name ((AVPixelFormat)avFrame->format )));
208
+ " Expected format to be AV_PIX_FMT_CUDA, got " ,
209
+ (av_get_pix_fmt_name ((AVPixelFormat)avFrame->format )
210
+ ? av_get_pix_fmt_name ((AVPixelFormat)avFrame->format )
211
+ : " unknown" ),
212
+ " . When that happens, it is probably because the video is not supported by NVDEC. "
213
+ " Try using the CPU device instead. "
214
+ " If the video is 10bit, we are tracking 10bit support in "
215
+ " https://github.com/pytorch/torchcodec/issues/776" );
216
+
217
+ // Above we checked that the AVFrame was on GPU, but that's not enough, we
218
+ // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
219
+ // because this is what the NPP color conversion routines expect.
220
+ // TODO: we should investigate how we can perform color conversion for
221
+ // non-8bit videos. This is supported on CPU.
222
+ TORCH_CHECK (
223
+ avFrame->hw_frames_ctx != nullptr ,
224
+ " The AVFrame does not have a hw_frames_ctx. "
225
+ " That's unexpected, please report this to the TorchCodec repo." );
226
+
227
+ auto hwFramesCtx =
228
+ reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
229
+ AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
230
+ TORCH_CHECK (
231
+ actualFormat == AV_PIX_FMT_NV12,
232
+ " The AVFrame is " ,
233
+ (av_get_pix_fmt_name (actualFormat) ? av_get_pix_fmt_name (actualFormat)
234
+ : " unknown" ),
235
+ " , but we expected AV_PIX_FMT_NV12. This typically happens when "
236
+ " the video isn't 8bit, which is not supported on CUDA at the moment. "
237
+ " Try using the CPU device instead. "
238
+ " If the video is 10bit, we are tracking 10bit support in "
239
+ " https://github.com/pytorch/torchcodec/issues/776" );
240
+
203
241
auto frameDims =
204
242
getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
205
243
int height = frameDims.height ;
0 commit comments