Skip to content

Commit 7dd12c1

Browse files
authored
Merge branch 'main' into generalize-deviceInterface
2 parents 159672d + d63504c commit 7dd12c1

File tree

12 files changed

+472
-178
lines changed

12 files changed

+472
-178
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 148 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,12 @@ bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) {
213213
return true;
214214
}
215215

216+
// Callback for freeing CUDA memory associated with AVFrame see where it's used
217+
// for more details.
218+
void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) {
219+
cudaFree(opaque);
220+
}
221+
216222
} // namespace
217223

218224
BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -668,40 +674,164 @@ void BetaCudaDeviceInterface::flush() {
668674
std::swap(readyFrames_, emptyQueue);
669675
}
670676

677+
UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
    UniqueAVFrame& cpuFrame) {
  // This is called in the context of the CPU fallback: the frame was decoded
  // on the CPU, and in this function we convert that frame into NV12 format
  // and send it to the GPU.
  // We do that in 2 steps:
  // - First we convert the input CPU frame into an intermediate NV12 CPU frame
  //   using sws_scale.
  // - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU. This
  //   is what we return.

  TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null");

  int width = cpuFrame->width;
  int height = cpuFrame->height;

  // Intermediate NV12 CPU frame. It's not on the GPU yet.
  UniqueAVFrame nv12CpuFrame(av_frame_alloc());
  TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");

  nv12CpuFrame->format = AV_PIX_FMT_NV12;
  nv12CpuFrame->width = width;
  nv12CpuFrame->height = height;

  int ret = av_frame_get_buffer(nv12CpuFrame.get(), 0);
  TORCH_CHECK(
      ret >= 0,
      "Failed to allocate NV12 CPU frame buffer: ",
      getFFMPEGErrorStringFromErrorCode(ret));

  // Re-create the sws context only when the input geometry/format changed
  // since the previous fallback frame.
  SwsFrameContext swsFrameContext(
      width,
      height,
      static_cast<AVPixelFormat>(cpuFrame->format),
      width,
      height);

  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
    swsContext_ = createSwsContext(
        swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
    prevSwsFrameContext_ = swsFrameContext;
  }

  int convertedHeight = sws_scale(
      swsContext_.get(),
      cpuFrame->data,
      cpuFrame->linesize,
      0,
      height,
      nv12CpuFrame->data,
      nv12CpuFrame->linesize);
  TORCH_CHECK(
      convertedHeight == height, "sws_scale failed for CPU->NV12 conversion");

  int ySize = width * height;
  TORCH_CHECK(
      ySize % 2 == 0,
      "Y plane size must be even. Please report on TorchCodec repo.");
  int uvSize = ySize / 2; // NV12: interleaved UV plane is half the Y plane
  size_t totalSize = static_cast<size_t>(ySize + uvSize);

  uint8_t* cudaBuffer = nullptr;
  cudaError_t err =
      cudaMalloc(reinterpret_cast<void**>(&cudaBuffer), totalSize);
  TORCH_CHECK(
      err == cudaSuccess,
      "Failed to allocate CUDA memory: ",
      cudaGetErrorString(err));

  // From here on, cudaBuffer is raw memory that FFmpeg doesn't know how to
  // free. We must hand its ownership to the GPU frame *before* any call that
  // can throw (TORCH_CHECK), otherwise a failed copy below would leak it.
  UniqueAVFrame gpuFrame(av_frame_alloc());
  if (gpuFrame == nullptr) {
    cudaFree(cudaBuffer); // no owning frame yet, free manually
    TORCH_CHECK(false, "Failed to allocate GPU AVFrame");
  }

  // Usually, AVFrame data is freed when av_frame_free() is called (upon
  // UniqueAVFrame destruction), but since we allocated the CUDA memory
  // ourselves, FFmpeg doesn't know how to free it. The recommended way to deal
  // with this is to associate the opaque_ref field of the AVFrame with a
  // `free` callback that will then be called by av_frame_free(). Attaching it
  // now (rather than at the end) guarantees the buffer is released even if one
  // of the TORCH_CHECKs below throws and gpuFrame is destroyed during stack
  // unwinding.
  gpuFrame->opaque_ref = av_buffer_create(
      nullptr, // data - we don't need any
      0, // data size
      cudaBufferFreeCallback, // callback triggered by av_frame_free()
      cudaBuffer, // parameter to callback
      0); // flags
  if (gpuFrame->opaque_ref == nullptr) {
    cudaFree(cudaBuffer); // ownership transfer failed, free manually
    TORCH_CHECK(false, "Failed to create GPU memory cleanup reference");
  }

  gpuFrame->format = AV_PIX_FMT_CUDA;
  gpuFrame->width = width;
  gpuFrame->height = height;
  gpuFrame->data[0] = cudaBuffer;
  gpuFrame->data[1] = cudaBuffer + ySize;
  gpuFrame->linesize[0] = width;
  gpuFrame->linesize[1] = width;

  // Note that we use cudaMemcpy2D here instead of cudaMemcpy because the
  // linesizes (strides) may be different than the widths for the input CPU
  // frame. That's precisely what cudaMemcpy2D is for.
  err = cudaMemcpy2D(
      gpuFrame->data[0],
      gpuFrame->linesize[0],
      nv12CpuFrame->data[0],
      nv12CpuFrame->linesize[0],
      width,
      height,
      cudaMemcpyHostToDevice);
  TORCH_CHECK(
      err == cudaSuccess,
      "Failed to copy Y plane to GPU: ",
      cudaGetErrorString(err));

  TORCH_CHECK(
      height % 2 == 0,
      "height must be even. Please report on TorchCodec repo.");
  err = cudaMemcpy2D(
      gpuFrame->data[1],
      gpuFrame->linesize[1],
      nv12CpuFrame->data[1],
      nv12CpuFrame->linesize[1],
      width,
      height / 2,
      cudaMemcpyHostToDevice);
  TORCH_CHECK(
      err == cudaSuccess,
      "Failed to copy UV plane to GPU: ",
      cudaGetErrorString(err));

  ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get());
  TORCH_CHECK(
      ret >= 0,
      "Failed to copy frame properties: ",
      getFFMPEGErrorStringFromErrorCode(ret));

  return gpuFrame;
}
813+
671814
void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
672815
UniqueAVFrame& avFrame,
673816
FrameOutput& frameOutput,
674817
[[maybe_unused]] AVMediaType mediaType,
675818
std::optional<torch::Tensor> preAllocatedOutputTensor) {
676-
if (cpuFallback_) {
677-
// CPU decoded frame - need to do CPU color conversion then transfer to GPU
678-
FrameOutput cpuFrameOutput;
679-
cpuFallback_->convertAVFrameToFrameOutput(
680-
avFrame, cpuFrameOutput, AVMEDIA_TYPE_VIDEO);
681-
682-
// Transfer CPU frame to GPU
683-
if (preAllocatedOutputTensor.has_value()) {
684-
preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
685-
frameOutput.data = preAllocatedOutputTensor.value();
686-
} else {
687-
frameOutput.data = cpuFrameOutput.data.to(device_);
688-
}
689-
return;
690-
}
819+
UniqueAVFrame gpuFrame =
820+
cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame);
691821

692822
// TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA
693823
// ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24().
694824
TORCH_CHECK(
695-
avFrame->format == AV_PIX_FMT_CUDA,
825+
gpuFrame->format == AV_PIX_FMT_CUDA,
696826
"Expected CUDA format frame from BETA CUDA interface");
697827

698-
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);
828+
validatePreAllocatedTensorShape(preAllocatedOutputTensor, gpuFrame);
699829

700830
at::cuda::CUDAStream nvdecStream =
701831
at::cuda::getCurrentCUDAStream(device_.index());
702832

703833
frameOutput.data = convertNV12FrameToRGB(
704-
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
834+
gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
705835
}
706836

707837
std::string BetaCudaDeviceInterface::getDetails() {

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
8282
unsigned int pitch,
8383
const CUVIDPARSERDISPINFO& dispInfo);
8484

85+
UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame);
86+
8587
CUvideoparser videoParser_ = nullptr;
8688
UniqueCUvideodecoder decoder_;
8789
CUVIDEOFORMAT videoFormat_ = {};
@@ -100,6 +102,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
100102

101103
std::unique_ptr<DeviceInterface> cpuFallback_;
102104
bool nvcuvidAvailable_ = false;
105+
UniqueSwsContext swsContext_;
106+
SwsFrameContext prevSwsFrameContext_;
103107
};
104108

105109
} // namespace facebook::torchcodec

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 2 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,6 @@ static bool g_cpu = registerDeviceInterface(
1515

1616
} // namespace
1717

18-
CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
19-
int inputWidth,
20-
int inputHeight,
21-
AVPixelFormat inputFormat,
22-
int outputWidth,
23-
int outputHeight)
24-
: inputWidth(inputWidth),
25-
inputHeight(inputHeight),
26-
inputFormat(inputFormat),
27-
outputWidth(outputWidth),
28-
outputHeight(outputHeight) {}
29-
30-
bool CpuDeviceInterface::SwsFrameContext::operator==(
31-
const CpuDeviceInterface::SwsFrameContext& other) const {
32-
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
33-
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
34-
outputHeight == other.outputHeight;
35-
}
36-
37-
bool CpuDeviceInterface::SwsFrameContext::operator!=(
38-
const CpuDeviceInterface::SwsFrameContext& other) const {
39-
return !(*this == other);
40-
}
41-
4218
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4319
: DeviceInterface(device) {
4420
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -276,7 +252,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
276252
outputDims.height);
277253

278254
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
279-
createSwsContext(swsFrameContext, avFrame->colorspace);
255+
swsContext_ = createSwsContext(
256+
swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
280257
prevSwsFrameContext_ = swsFrameContext;
281258
}
282259

@@ -295,51 +272,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
295272
return resultHeight;
296273
}
297274

298-
void CpuDeviceInterface::createSwsContext(
299-
const SwsFrameContext& swsFrameContext,
300-
const enum AVColorSpace colorspace) {
301-
SwsContext* swsContext = sws_getContext(
302-
swsFrameContext.inputWidth,
303-
swsFrameContext.inputHeight,
304-
swsFrameContext.inputFormat,
305-
swsFrameContext.outputWidth,
306-
swsFrameContext.outputHeight,
307-
AV_PIX_FMT_RGB24,
308-
swsFlags_,
309-
nullptr,
310-
nullptr,
311-
nullptr);
312-
TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
313-
314-
int* invTable = nullptr;
315-
int* table = nullptr;
316-
int srcRange, dstRange, brightness, contrast, saturation;
317-
int ret = sws_getColorspaceDetails(
318-
swsContext,
319-
&invTable,
320-
&srcRange,
321-
&table,
322-
&dstRange,
323-
&brightness,
324-
&contrast,
325-
&saturation);
326-
TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
327-
328-
const int* colorspaceTable = sws_getCoefficients(colorspace);
329-
ret = sws_setColorspaceDetails(
330-
swsContext,
331-
colorspaceTable,
332-
srcRange,
333-
colorspaceTable,
334-
dstRange,
335-
brightness,
336-
contrast,
337-
saturation);
338-
TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
339-
340-
swsContext_.reset(swsContext);
341-
}
342-
343275
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
344276
const UniqueAVFrame& avFrame,
345277
const FrameDims& outputDims) {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,6 @@ class CpuDeviceInterface : public DeviceInterface {
6969
ColorConversionLibrary getColorConversionLibrary(
7070
const FrameDims& inputFrameDims) const;
7171

72-
struct SwsFrameContext {
73-
int inputWidth = 0;
74-
int inputHeight = 0;
75-
AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
76-
int outputWidth = 0;
77-
int outputHeight = 0;
78-
79-
SwsFrameContext() = default;
80-
SwsFrameContext(
81-
int inputWidth,
82-
int inputHeight,
83-
AVPixelFormat inputFormat,
84-
int outputWidth,
85-
int outputHeight);
86-
bool operator==(const SwsFrameContext&) const;
87-
bool operator!=(const SwsFrameContext&) const;
88-
};
89-
90-
void createSwsContext(
91-
const SwsFrameContext& swsFrameContext,
92-
const enum AVColorSpace colorspace);
93-
9472
VideoStreamOptions videoStreamOptions_;
9573
AVRational timeBase_;
9674

src/torchcodec/_core/Encoder.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
531531
frames.sizes()[1] == 3,
532532
"frame must have 3 channels (R, G, B), got ",
533533
frames.sizes()[1]);
534-
// TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
535534
return frames.contiguous();
536535
}
537536

0 commit comments

Comments
 (0)