|
5 | 5 | #include <mutex> |
6 | 6 |
|
7 | 7 | #include "src/torchcodec/_core/Cache.h" |
| 8 | +#include "src/torchcodec/_core/CpuDeviceInterface.h" |
8 | 9 | #include "src/torchcodec/_core/CudaDeviceInterface.h" |
9 | 10 | #include "src/torchcodec/_core/FFMPEGCommon.h" |
10 | 11 |
|
@@ -230,7 +231,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( |
230 | 231 | reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data); |
231 | 232 | AVPixelFormat actualFormat = hwFramesCtx->sw_format; |
232 | 233 |
|
233 | | - // NV12 conversion is implemented directly with NPP, no need for filters. |
| 234 | + // If the frame is already in NV12 format, we don't need to do anything. |
234 | 235 | if (actualFormat == AV_PIX_FMT_NV12) { |
235 | 236 | return std::move(avFrame); |
236 | 237 | } |
@@ -310,35 +311,64 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( |
310 | 311 | UniqueAVFrame& avFrame, |
311 | 312 | FrameOutput& frameOutput, |
312 | 313 | std::optional<torch::Tensor> preAllocatedOutputTensor) { |
| 314 | + if (preAllocatedOutputTensor.has_value()) { |
| 315 | + auto shape = preAllocatedOutputTensor.value().sizes(); |
| 316 | + TORCH_CHECK( |
| 317 | + (shape.size() == 3) && (shape[0] == outputDims_.height) && |
| 318 | + (shape[1] == outputDims_.width) && (shape[2] == 3), |
| 319 | + "Expected tensor of shape ", |
| 320 | + outputDims_.height, |
| 321 | + "x", |
| 322 | + outputDims_.width, |
| 323 | + "x3, got ", |
| 324 | + shape); |
| 325 | + } |
| 326 | + |
| 327 | + // All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by |
| 328 | + // converting them to NV12. |
313 | 329 | avFrame = maybeConvertAVFrameToNV12(avFrame); |
314 | 330 |
|
315 | | - // The filtered frame might be on CPU if CPU fallback has happened on filter |
316 | | - // graph level. For example, that's how we handle color format conversion |
317 | | - // on FFmpeg 4.4 where scale_cuda did not have this support implemented yet. |
318 | 331 | if (avFrame->format != AV_PIX_FMT_CUDA) { |
319 | 332 | // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on |
320 | | - // the GPU. In this branch, the frame is on the CPU: this is what NVDEC |
321 | | - // gives us if it wasn't able to decode a frame, for whatever reason. |
322 | | - // Typically that happens if the video's encoder isn't supported by NVDEC. |
323 | | - // Below, we choose to convert the frame's color-space using the CPU |
324 | | - // codepath, and send it back to the GPU at the very end. |
| 333 | + // the GPU. In this branch, the frame is on the CPU. There are two possible |
| 334 | + // reasons: |
| 335 | + // |
| 336 | + // 1. During maybeConvertAVFrameToNV12(), we had a non-NV12 format frame |
| 337 | + // and we're on FFmpeg 4.4 or earlier. In such cases, we had to use CPU |
| 338 | + // filters and we just converted the frame to RGB24. |
| 339 | + // 2. This is what NVDEC gave us if it wasn't able to decode a frame, for |
| 340 | + // whatever reason. Typically that happens if the video's encoder isn't |
| 341 | + // supported by NVDEC. |
325 | 342 | // |
326 | | - // TODO: A possibly better solution would be to send the frame to the GPU |
327 | | - // first, and do the color conversion there. |
| 343 | + // In both cases, we have a frame on the CPU, and we need a CPU device to |
| 344 | + // handle it. We send the frame back to the CUDA device when we're done. |
328 | 345 | // |
329 | | - // TODO: If we're going to keep this around, we should probably cache it? |
330 | | - auto cpuInterface = createDeviceInterface(torch::Device(torch::kCPU)); |
| 346 | + // TODO: Perhaps we should cache cpuInterface? |
| 347 | + auto cpuInterface = std::make_unique<CpuDeviceInterface>(torch::kCPU); |
331 | 348 | TORCH_CHECK( |
332 | 349 | cpuInterface != nullptr, "Failed to create CPU device interface"); |
333 | 350 | cpuInterface->initialize( |
334 | 351 | nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); |
335 | 352 |
|
| 353 | + enum AVPixelFormat frameFormat = |
| 354 | + static_cast<enum AVPixelFormat>(avFrame->format); |
| 355 | + |
336 | 356 | FrameOutput cpuFrameOutput; |
337 | | - cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); |
338 | 357 |
|
339 | | - // TODO: explain that the pre-allocated tensor is on the GPU, but we need |
340 | | - // to do the decoding on the CPU, and we can't pass the pre-allocated tensor |
341 | | - // to do it. BUT WHY did it work before? |
| 358 | + if (frameFormat == AV_PIX_FMT_RGB24 && |
| 359 | + avFrame->width == outputDims_.width && |
| 360 | + avFrame->height == outputDims_.height) { |
| 361 | + // Reason 1 above. The frame is already in the format and dimensions that |
| 362 | + // we need, we just need to convert it to a tensor. |
| 363 | + cpuFrameOutput.data = cpuInterface->toTensor(avFrame); |
| 364 | + } else { |
| 365 | + // Reason 2 above. We need to do a full conversion. |
| 366 | + cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); |
| 367 | + } |
| 368 | + |
| 369 | + // Finally, we need to send the frame back to the GPU. Note that the |
| 370 | + // pre-allocated tensor is on the GPU, so we can't send that to the CPU |
| 371 | + // device interface. We copy it over here. |
342 | 372 | if (preAllocatedOutputTensor.has_value()) { |
343 | 373 | preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data); |
344 | 374 | frameOutput.data = preAllocatedOutputTensor.value(); |
@@ -372,16 +402,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( |
372 | 402 | torch::Tensor& dst = frameOutput.data; |
373 | 403 | if (preAllocatedOutputTensor.has_value()) { |
374 | 404 | dst = preAllocatedOutputTensor.value(); |
375 | | - auto shape = dst.sizes(); |
376 | | - TORCH_CHECK( |
377 | | - (shape.size() == 3) && (shape[0] == outputDims_.height) && |
378 | | - (shape[1] == outputDims_.width) && (shape[2] == 3), |
379 | | - "Expected tensor of shape ", |
380 | | - outputDims_.height, |
381 | | - "x", |
382 | | - outputDims_.width, |
383 | | - "x3, got ", |
384 | | - shape); |
385 | 405 | } else { |
386 | 406 | dst = allocateEmptyHWCTensor(outputDims_, device_); |
387 | 407 | } |
|
0 commit comments