
Commit 340974a

Docs
1 parent f4c8f4e commit 340974a


3 files changed: +55 -42 lines changed


src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 28 additions & 10 deletions
@@ -212,6 +212,12 @@ bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) {
   return true;
 }
 
+// Callback for freeing CUDA memory associated with AVFrame; see where it's used
+// for more details.
+void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) {
+  cudaFree(opaque);
+}
+
 } // namespace
 
 BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -665,20 +671,23 @@ void BetaCudaDeviceInterface::flush() {
   std::swap(readyFrames_, emptyQueue);
 }
 
-namespace {
-// Cleanup callback for CUDA memory allocated for GPU frames
-void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) {
-  cudaFree(opaque);
-}
-} // namespace
-
 UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
     UniqueAVFrame& cpuFrame) {
+  // This is called in the context of the CPU fallback: the frame was decoded on
+  // the CPU, and in this function we convert that frame into NV12 format and
+  // send it to the GPU.
+  // We do that in 2 steps:
+  // - First we convert the input CPU frame into an intermediate NV12 CPU frame
+  //   using sws_scale.
+  // - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU. This
+  //   is what we return.
+
   TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null");
 
   int width = cpuFrame->width;
   int height = cpuFrame->height;
 
+  // Intermediate NV12 CPU frame. It's not on the GPU yet.
   UniqueAVFrame nv12CpuFrame(av_frame_alloc());
   TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");
 
@@ -707,7 +716,7 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
 
   int convertedHeight = sws_scale(
       swsContext_.get(),
-      const_cast<const uint8_t* const*>(cpuFrame->data),
+      cpuFrame->data,
       cpuFrame->linesize,
       0,
       height,
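
The two steps described in the comments above follow the usual swscale pattern: build an NV12 destination frame, then let sws_scale do the pixel-format conversion. As a rough, self-contained sketch of the first step only (hypothetical function name, no error handling, and none of the TorchCodec wrappers such as UniqueAVFrame or the cached swsContext_), it might look like this:

// Hypothetical sketch of the CPU-side NV12 conversion; not the commit's code.
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

AVFrame* convertToNV12OnCpu(const AVFrame* src) {
  AVFrame* dst = av_frame_alloc();
  dst->format = AV_PIX_FMT_NV12;
  dst->width = src->width;
  dst->height = src->height;
  av_frame_get_buffer(dst, 0);  // allocates dst->data / dst->linesize

  SwsContext* sws = sws_getContext(
      src->width, src->height, static_cast<AVPixelFormat>(src->format),
      dst->width, dst->height, AV_PIX_FMT_NV12,
      SWS_BILINEAR, nullptr, nullptr, nullptr);

  // As in the hunk above, src->data converts implicitly to
  // const uint8_t* const*, so no const_cast is needed.
  sws_scale(sws, src->data, src->linesize, 0, src->height,
            dst->data, dst->linesize);

  sws_freeContext(sws);
  return dst;  // caller releases with av_frame_free(&dst)
}
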
@@ -739,6 +748,9 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
   gpuFrame->linesize[0] = width;
   gpuFrame->linesize[1] = width;
 
+  // Note that we use cudaMemcpy2D here instead of cudaMemcpy because the
+  // linesizes (strides) may be different than the widths for the input CPU
+  // frame. That's precisely what cudaMemcpy2D is for.
   err = cudaMemcpy2D(
       gpuFrame->data[0],
       gpuFrame->linesize[0],
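
For context, the strided copy the new comment refers to follows the standard cudaMemcpy2D pattern: copy `width` bytes per row for `height` rows, honoring a separate pitch on the source and the destination. A minimal sketch with illustrative names (not the commit's variables):

// Minimal sketch of a pitched host-to-device plane copy; names are illustrative.
#include <cstdint>
#include <cuda_runtime.h>

cudaError_t copyPlaneToDevice(
    uint8_t* devPlane, size_t devPitch,          // destination pointer and row pitch on the GPU
    const uint8_t* hostPlane, size_t hostPitch,  // source pointer and row pitch (linesize) on the CPU
    size_t widthBytes, size_t heightRows) {
  // Unlike cudaMemcpy, cudaMemcpy2D skips the padding bytes at the end of each
  // row whenever a pitch is larger than the row width.
  return cudaMemcpy2D(
      devPlane, devPitch, hostPlane, hostPitch,
      widthBytes, heightRows, cudaMemcpyHostToDevice);
}
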
@@ -771,10 +783,16 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
       "Failed to copy frame properties: ",
       getFFMPEGErrorStringFromErrorCode(ret));
 
+  // We're almost done, but we need to make sure the CUDA memory is freed
+  // properly. Usually, AVFrame data is freed when av_frame_free() is called
+  // (upon UniqueAVFrame destruction), but since we allocated the CUDA memory
+  // ourselves, FFmpeg doesn't know how to free it. The recommended way to deal
+  // with this is to associate the opaque_ref field of the AVFrame with a `free`
+  // callback that will then be called by av_frame_free().
   gpuFrame->opaque_ref = av_buffer_create(
-      nullptr, // data
+      nullptr, // data - we don't need any
       0, // data size
-      cudaBufferFreeCallback, // callback triggered by av_frame_free()
+      cudaBufferFreeCallback, // callback triggered by av_frame_free()
       cudaBuffer, // parameter to callback
       0); // flags
   TORCH_CHECK(
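
Attaching the free callback through opaque_ref is FFmpeg's standard way to tie externally allocated memory to a frame's lifetime. A hedged, self-contained sketch of the pattern (hypothetical function names; the commit itself wires this up directly inside transferCpuFrameToGpuNV12):

// Sketch of tying a cudaMalloc'd buffer to an AVFrame's lifetime via opaque_ref.
// Illustrative only; not the commit's code.
#include <cuda_runtime.h>
extern "C" {
#include <libavutil/buffer.h>
#include <libavutil/frame.h>
}

static void freeCudaOpaque(void* opaque, uint8_t* /*data*/) {
  cudaFree(opaque);
}

bool attachCudaBufferToFrame(AVFrame* frame, void* cudaBuffer) {
  frame->opaque_ref = av_buffer_create(
      nullptr,        // no data owned by the AVBufferRef itself
      0,              // data size
      freeCudaOpaque, // invoked when the last reference is released
      cudaBuffer,     // handed back to the callback as `opaque`
      0);             // flags
  // When av_frame_free() (or UniqueAVFrame's deleter) runs, opaque_ref is
  // unreferenced and freeCudaOpaque() releases the CUDA allocation.
  return frame->opaque_ref != nullptr;
}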

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 1 addition & 4 deletions
@@ -81,7 +81,6 @@ class BetaCudaDeviceInterface : public DeviceInterface {
       unsigned int pitch,
       const CUVIDPARSERDISPINFO& dispInfo);
 
-  // Convert CPU frame to GPU NV12 frame for GPU color conversion
   UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame);
 
   CUvideoparser videoParser_ = nullptr;
@@ -100,11 +99,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
   // NPP context for color conversion
   UniqueNppContext nppCtx_;
 
-  // Swscale context caching for CPU->GPU NV12 conversion
+  std::unique_ptr<DeviceInterface> cpuFallback_;
   UniqueSwsContext swsContext_;
   SwsFrameContext prevSwsFrameContext_;
-
-  std::unique_ptr<DeviceInterface> cpuFallback_;
 };
 
 } // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 26 additions & 28 deletions
@@ -104,34 +104,6 @@ using UniqueAVBufferSrcParameters = std::unique_ptr<
     AVBufferSrcParameters,
     Deleterv<AVBufferSrcParameters, void, av_freep>>;
 
-// Common swscale context management for efficient reuse across device
-// interfaces
-struct SwsFrameContext {
-  int inputWidth = 0;
-  int inputHeight = 0;
-  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
-  int outputWidth = 0;
-  int outputHeight = 0;
-
-  SwsFrameContext() = default;
-  SwsFrameContext(
-      int inputWidth,
-      int inputHeight,
-      AVPixelFormat inputFormat,
-      int outputWidth,
-      int outputHeight);
-
-  bool operator==(const SwsFrameContext& other) const;
-  bool operator!=(const SwsFrameContext& other) const;
-};
-
-// Utility functions for swscale context management
-UniqueSwsContext createSwsContext(
-    const SwsFrameContext& swsFrameContext,
-    AVColorSpace colorspace,
-    AVPixelFormat outputFormat = AV_PIX_FMT_RGB24,
-    int swsFlags = SWS_BILINEAR);
-
 // These 2 classes share the same underlying AVPacket object. They are meant to
 // be used in tandem, like so:
 //
@@ -279,4 +251,30 @@ AVFilterContext* createBuffersinkFilter(
     AVFilterGraph* filterGraph,
     enum AVPixelFormat outputFormat);
 
+struct SwsFrameContext {
+  int inputWidth = 0;
+  int inputHeight = 0;
+  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+  int outputWidth = 0;
+  int outputHeight = 0;
+
+  SwsFrameContext() = default;
+  SwsFrameContext(
+      int inputWidth,
+      int inputHeight,
+      AVPixelFormat inputFormat,
+      int outputWidth,
+      int outputHeight);
+
+  bool operator==(const SwsFrameContext& other) const;
+  bool operator!=(const SwsFrameContext& other) const;
+};
+
+// Utility functions for swscale context management
+UniqueSwsContext createSwsContext(
+    const SwsFrameContext& swsFrameContext,
+    AVColorSpace colorspace,
+    AVPixelFormat outputFormat = AV_PIX_FMT_RGB24,
+    int swsFlags = SWS_BILINEAR);
+
 } // namespace facebook::torchcodec
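
The relocated declarations support caching a swscale context across frames: rebuild the context only when the frame geometry or pixel format changes, which is what the operator==/operator!= pair is for. An assumed usage sketch (not code from this commit; the NV12 output format and the swsContext_/prevSwsFrameContext_ names come from BetaCudaDeviceInterface.h above):

// Assumed usage sketch of SwsFrameContext-based caching; not part of the commit.
UniqueSwsContext& getOrCreateSwsContext(
    UniqueSwsContext& swsContext,       // cached context, e.g. swsContext_
    SwsFrameContext& prevFrameContext,  // e.g. prevSwsFrameContext_
    const AVFrame* cpuFrame,
    int outputWidth,
    int outputHeight) {
  SwsFrameContext current(
      cpuFrame->width,
      cpuFrame->height,
      static_cast<AVPixelFormat>(cpuFrame->format),
      outputWidth,
      outputHeight);
  if (!swsContext || current != prevFrameContext) {
    // Only pay for context creation when something actually changed.
    swsContext = createSwsContext(current, cpuFrame->colorspace, AV_PIX_FMT_NV12);
    prevFrameContext = current;
  }
  return swsContext;
}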
