Skip to content

Commit 926b7ea

Browse files
committed
separate encoding frame ctx init
1 parent 88e1299 commit 926b7ea

File tree

5 files changed

+18
-1
lines changed

5 files changed

+18
-1
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ void CudaDeviceInterface::registerHardwareDeviceWithCodec(
144144
hardwareDeviceCtx_, "Hardware device context has not been initialized");
145145
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
146146
codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
147+
}
148+
149+
void CudaDeviceInterface::setupEncodingContext(AVCodecContext* codecContext) {
150+
TORCH_CHECK(
151+
hardwareDeviceCtx_, "Hardware device context has not been initialized");
152+
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
147153
// is there any way to preserve actual desired format?
148154
// codecContext->sw_pix_fmt = codecContext->pix_fmt;
149155
// Should we always produce AV_PIX_FMT_NV12?

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class CudaDeviceInterface : public DeviceInterface {
3535

3636
void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override;
3737

38+
void setupEncodingContext(AVCodecContext* codecContext) override;
39+
3840
void convertAVFrameToFrameOutput(
3941
UniqueAVFrame& avFrame,
4042
FrameOutput& frameOutput,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ class DeviceInterface {
9292
virtual void registerHardwareDeviceWithCodec(
9393
[[maybe_unused]] AVCodecContext* codecContext) {}
9494

95+
// Setup device-specific encoding context (e.g., hardware frame contexts).
96+
// Called after registerHardwareDeviceWithCodec for encoders.
97+
// Default implementation does nothing (suitable for CPU and basic cases).
98+
virtual void setupEncodingContext(
99+
[[maybe_unused]] AVCodecContext* codecContext) {}
100+
95101
virtual void convertAVFrameToFrameOutput(
96102
UniqueAVFrame& avFrame,
97103
FrameOutput& frameOutput,

src/torchcodec/_core/Encoder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,9 @@ void VideoEncoder::initializeEncoder(
850850
// context before calling avcodec_open2().
851851
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
852852

853+
// Setup device-specific encoding context (e.g., hardware frame contexts)
854+
deviceInterface_->setupEncodingContext(avCodecContext_.get());
855+
853856
int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions);
854857
av_dict_free(&avCodecOptions);
855858

test/test_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ def test_contiguity(self, method, tmp_path, device):
775775
num_frames, channels, height, width = 5, 3, 256, 256
776776
contiguous_frames = torch.randint(
777777
0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8
778-
).contiguous()
778+
).contiguous().to(device)
779779
assert contiguous_frames.is_contiguous()
780780

781781
# Permute NCHW to NHWC, then update the memory layout, then permute back

0 commit comments

Comments
 (0)