@@ -199,12 +199,121 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199
199
return ;
200
200
}
201
201
202
+ std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext (
203
+ const VideoStreamOptions& videoStreamOptions,
204
+ const UniqueAVFrame& avFrame,
205
+ const AVRational& timeBase) {
206
+ // We need FFmpeg filters to handle those conversion cases which are not
207
+ // directly implemented in CUDA or CPU device interface (in case of a
208
+ // fallback).
209
+ enum AVPixelFormat frameFormat =
210
+ static_cast <enum AVPixelFormat>(avFrame->format );
211
+
212
+ // Input frame is on CPU, we will just pass it to CPU device interface, so
213
+ // skipping filters context as CPU device interface will handle everythong for
214
+ // us.
215
+ if (avFrame->format != AV_PIX_FMT_CUDA) {
216
+ return nullptr ;
217
+ }
218
+
219
+ TORCH_CHECK (
220
+ avFrame->hw_frames_ctx != nullptr ,
221
+ " The AVFrame does not have a hw_frames_ctx. "
222
+ " That's unexpected, please report this to the TorchCodec repo." );
223
+
224
+ auto hwFramesCtx =
225
+ reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
226
+ AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
227
+
228
+ // NV12 conversion is implemented directly with NPP, no need for filters.
229
+ if (actualFormat == AV_PIX_FMT_NV12) {
230
+ return nullptr ;
231
+ }
232
+
233
+ auto frameDims =
234
+ getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
235
+ int height = frameDims.height ;
236
+ int width = frameDims.width ;
237
+
238
+ AVPixelFormat outputFormat;
239
+ std::stringstream filters;
240
+
241
+ unsigned version_int = avfilter_version ();
242
+ if (version_int < AV_VERSION_INT (8 , 0 , 103 )) {
243
+ // Color conversion support ('format=' option) was added to scale_cuda from
244
+ // n5.0. With the earlier version of ffmpeg we have no choice but use CPU
245
+ // filters. See:
246
+ // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247
+ outputFormat = AV_PIX_FMT_RGB24;
248
+
249
+ filters << " hwdownload,format=" << av_pix_fmt_desc_get (actualFormat)->name ;
250
+ filters << " ,scale=" << width << " :" << height;
251
+ filters << " :sws_flags=bilinear" ;
252
+ } else {
253
+ // Actual output color format will be set via filter options
254
+ outputFormat = AV_PIX_FMT_CUDA;
255
+
256
+ filters << " scale_cuda=" << width << " :" << height;
257
+ filters << " :format=nv12:interp_algo=bilinear" ;
258
+ }
259
+
260
+ return std::make_unique<FiltersContext>(
261
+ avFrame->width ,
262
+ avFrame->height ,
263
+ frameFormat,
264
+ avFrame->sample_aspect_ratio ,
265
+ width,
266
+ height,
267
+ outputFormat,
268
+ filters.str (),
269
+ timeBase,
270
+ av_buffer_ref (avFrame->hw_frames_ctx ));
271
+ }
272
+
202
273
void CudaDeviceInterface::convertAVFrameToFrameOutput (
203
274
const VideoStreamOptions& videoStreamOptions,
204
275
[[maybe_unused]] const AVRational& timeBase,
205
- UniqueAVFrame& avFrame ,
276
+ UniqueAVFrame& avInputFrame ,
206
277
FrameOutput& frameOutput,
207
278
std::optional<torch::Tensor> preAllocatedOutputTensor) {
279
+ std::unique_ptr<FiltersContext> newFiltersContext =
280
+ initializeFiltersContext (videoStreamOptions, avInputFrame, timeBase);
281
+ UniqueAVFrame avFilteredFrame;
282
+ if (newFiltersContext) {
283
+ // We need to compare the current filter context with our previous filter
284
+ // context. If they are different, then we need to re-create a filter
285
+ // graph. We create a filter graph late so that we don't have to depend
286
+ // on the unreliable metadata in the header. And we sometimes re-create
287
+ // it because it's possible for frame resolution to change mid-stream.
288
+ // Finally, we want to reuse the filter graph as much as possible for
289
+ // performance reasons.
290
+ if (!filterGraph_ || filtersContext_ != newFiltersContext) {
291
+ filterGraph_ =
292
+ std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
293
+ filtersContext_ = std::move (newFiltersContext);
294
+ }
295
+ avFilteredFrame = filterGraph_->convert (avInputFrame);
296
+
297
+ // If this check fails it means the frame wasn't
298
+ // reshaped to its expected dimensions by filtergraph.
299
+ TORCH_CHECK (
300
+ (avFilteredFrame->width == filtersContext_->outputWidth ) &&
301
+ (avFilteredFrame->height == filtersContext_->outputHeight ),
302
+ " Expected frame from filter graph of " ,
303
+ filtersContext_->outputWidth ,
304
+ " x" ,
305
+ filtersContext_->outputHeight ,
306
+ " , got " ,
307
+ avFilteredFrame->width ,
308
+ " x" ,
309
+ avFilteredFrame->height );
310
+ }
311
+
312
+ UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
313
+
314
+  // The filtered frame might be on CPU if a CPU fallback has happened at the
315
+  // filter graph level. For example, that's how we handle color format
316
+  // conversion on FFmpeg 4.4, where scale_cuda did not support it yet.
208
317
if (avFrame->format != AV_PIX_FMT_CUDA) {
209
318
// The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
210
319
// the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +341,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
232
341
// Above we checked that the AVFrame was on GPU, but that's not enough, we
233
342
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
234
343
// because this is what the NPP color conversion routines expect.
235
- // TODO: we should investigate how to can perform color conversion for
236
- // non-8bit videos. This is supported on CPU.
237
344
TORCH_CHECK (
238
345
avFrame->hw_frames_ctx != nullptr ,
239
346
" The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +349,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
242
349
auto hwFramesCtx =
243
350
reinterpret_cast <AVHWFramesContext*>(avFrame->hw_frames_ctx ->data );
244
351
AVPixelFormat actualFormat = hwFramesCtx->sw_format ;
352
+
245
353
TORCH_CHECK (
246
354
actualFormat == AV_PIX_FMT_NV12,
247
355
" The AVFrame is " ,
248
356
(av_get_pix_fmt_name (actualFormat) ? av_get_pix_fmt_name (actualFormat)
249
357
: " unknown" ),
250
- " , but we expected AV_PIX_FMT_NV12. This typically happens when "
251
- " the video isn't 8bit, which is not supported on CUDA at the moment. "
252
- " Try using the CPU device instead. "
253
- " If the video is 10bit, we are tracking 10bit support in "
254
- " https://github.com/pytorch/torchcodec/issues/776" );
358
+ " , but we expected AV_PIX_FMT_NV12. "
359
+ " That's unexpected, please report this to the TorchCodec repo." );
255
360
256
361
auto frameDims =
257
362
getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
0 commit comments