Skip to content

Commit 11d7adf

Browse files
author
Molly Xu
committed
use shared_ptr for codecContext
1 parent 73fe68b commit 11d7adf

11 files changed

+54
-41
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,10 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
231231

232232
void BetaCudaDeviceInterface::initialize(
233233
const AVStream* avStream,
234-
const UniqueDecodingAVFormatContext& avFormatCtx) {
234+
const UniqueDecodingAVFormatContext& avFormatCtx,
235+
const SharedAVCodecContext& codecContext) {
235236
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
237+
codecContext_ = codecContext;
236238
timeBase_ = avStream->time_base;
237239
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
238240

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
4040

4141
void initialize(
4242
const AVStream* avStream,
43-
const UniqueDecodingAVFormatContext& avFormatCtx) override;
43+
const UniqueDecodingAVFormatContext& avFormatCtx,
44+
const SharedAVCodecContext& codecContext) override;
4445

4546
void convertAVFrameToFrameOutput(
4647
UniqueAVFrame& avFrame,

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4848

4949
void CpuDeviceInterface::initialize(
5050
const AVStream* avStream,
51-
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {
51+
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx,
52+
const SharedAVCodecContext& codecContext) {
5253
TORCH_CHECK(avStream != nullptr, "avStream is null");
54+
codecContext_ = codecContext;
5355
timeBase_ = avStream->time_base;
5456
}
5557

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ class CpuDeviceInterface : public DeviceInterface {
2525

2626
virtual void initialize(
2727
const AVStream* avStream,
28-
const UniqueDecodingAVFormatContext& avFormatCtx) override;
28+
const UniqueDecodingAVFormatContext& avFormatCtx,
29+
const SharedAVCodecContext& codecContext) override;
2930

3031
virtual void initializeVideo(
3132
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,17 @@ CudaDeviceInterface::~CudaDeviceInterface() {
117117

118118
void CudaDeviceInterface::initialize(
119119
const AVStream* avStream,
120-
const UniqueDecodingAVFormatContext& avFormatCtx) {
120+
const UniqueDecodingAVFormatContext& avFormatCtx,
121+
const SharedAVCodecContext& codecContext) {
121122
TORCH_CHECK(avStream != nullptr, "avStream is null");
123+
codecContext_ = codecContext;
122124
timeBase_ = avStream->time_base;
123125

124126
// TODO: Ideally, we should keep all interface implementations independent.
125127
cpuInterface_ = createDeviceInterface(torch::kCPU);
126128
TORCH_CHECK(
127129
cpuInterface_ != nullptr, "Failed to create CPU device interface");
128-
cpuInterface_->initialize(avStream, avFormatCtx);
130+
cpuInterface_->initialize(avStream, avFormatCtx, codecContext);
129131
cpuInterface_->initializeVideo(
130132
VideoStreamOptions(),
131133
{},

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class CudaDeviceInterface : public DeviceInterface {
2222

2323
void initialize(
2424
const AVStream* avStream,
25-
const UniqueDecodingAVFormatContext& avFormatCtx) override;
25+
const UniqueDecodingAVFormatContext& avFormatCtx,
26+
const SharedAVCodecContext& codecContext) override;
2627

2728
void initializeVideo(
2829
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class DeviceInterface {
5454
// Initialize the device with parameters generic to all kinds of decoding.
5555
virtual void initialize(
5656
const AVStream* avStream,
57-
const UniqueDecodingAVFormatContext& avFormatCtx) = 0;
57+
const UniqueDecodingAVFormatContext& avFormatCtx,
58+
const SharedAVCodecContext& codecContext) = 0;
5859

5960
// Initialize the device with parameters specific to video decoding. There is
6061
// a default empty implementation.
@@ -80,23 +81,14 @@ class DeviceInterface {
8081
// Extension points for custom decoding paths
8182
// ------------------------------------------
8283

83-
// Set the codec context for default FFmpeg decoding operations
84-
// This must be called during initialization before using
85-
// sendPacket/receiveFrame
86-
virtual void setCodecContext(AVCodecContext* codecContext) {
87-
codecContext_ = codecContext;
88-
}
89-
9084
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
9185
// other AVERROR on failure
9286
// Default implementation uses FFmpeg directly
9387
virtual int sendPacket(ReferenceAVPacket& avPacket) {
94-
if (!codecContext_) {
95-
TORCH_CHECK(
96-
false, "Codec context not available for default packet sending");
97-
return AVERROR(EINVAL);
98-
}
99-
return avcodec_send_packet(codecContext_, avPacket.get());
88+
TORCH_CHECK(
89+
codecContext_ != nullptr,
90+
"Codec context not available for default packet sending");
91+
return avcodec_send_packet(codecContext_.get(), avPacket.get());
10092
}
10193

10294
// Send an EOF packet to flush the decoder
@@ -107,29 +99,30 @@ class DeviceInterface {
10799
TORCH_CHECK(false, "Codec context not available for EOF packet sending");
108100
return AVERROR(EINVAL);
109101
}
110-
return avcodec_send_packet(codecContext_, nullptr);
102+
return avcodec_send_packet(codecContext_.get(), nullptr);
111103
}
112104

113105
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
114106
// AVERROR_EOF if end of stream, or other AVERROR on failure
115107
// Default implementation uses FFmpeg directly
116108
virtual int receiveFrame(UniqueAVFrame& avFrame) {
117-
if (!codecContext_) {
118-
TORCH_CHECK(false, "Codec context not available for frame receiving");
119-
return AVERROR(EINVAL);
120-
}
121-
return avcodec_receive_frame(codecContext_, avFrame.get());
109+
TORCH_CHECK(
110+
codecContext_ != nullptr,
111+
"Codec context not available for default frame receiving");
112+
return avcodec_receive_frame(codecContext_.get(), avFrame.get());
122113
}
123114

124115
// Flush remaining frames from decoder
125116
virtual void flush() {
126-
// Default implementation is no-op for standard decoders
127-
// Custom decoders can override this method
117+
TORCH_CHECK(
118+
codecContext_ != nullptr,
119+
"Codec context not available for default flushing");
120+
avcodec_flush_buffers(codecContext_.get());
128121
}
129122

130123
protected:
131124
torch::Device device_;
132-
AVCodecContext* codecContext_ = nullptr; // Non-owning pointer
125+
SharedAVCodecContext codecContext_;
133126
};
134127

135128
using CreateDeviceInterfaceFn =

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,15 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
158158
#endif
159159
}
160160

161+
int getNumChannels(const SharedAVCodecContext& avCodecContext) {
162+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
163+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
164+
return avCodecContext->ch_layout.nb_channels;
165+
#else
166+
return avCodecContext->channels;
167+
#endif
168+
}
169+
161170
void setDefaultChannelLayout(
162171
UniqueAVCodecContext& avCodecContext,
163172
int numChannels) {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ using UniqueEncodingAVFormatContext = std::unique_ptr<
7171
using UniqueAVCodecContext = std::unique_ptr<
7272
AVCodecContext,
7373
Deleterp<AVCodecContext, void, avcodec_free_context>>;
74+
using SharedAVCodecContext = std::shared_ptr<AVCodecContext>;
75+
7476
using UniqueAVFrame =
7577
std::unique_ptr<AVFrame, Deleterp<AVFrame, void, av_frame_free>>;
7678
using UniqueAVFilterGraph = std::unique_ptr<
@@ -172,6 +174,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
172174

173175
int getNumChannels(const UniqueAVFrame& avFrame);
174176
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
177+
int getNumChannels(const SharedAVCodecContext& avCodecContext);
175178

176179
void setDefaultChannelLayout(
177180
UniqueAVCodecContext& avCodecContext,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,6 @@ void SingleStreamDecoder::addStream(
429429
TORCH_CHECK(
430430
deviceInterface_ != nullptr,
431431
"Failed to create device interface. This should never happen, please report.");
432-
deviceInterface_->initialize(streamInfo.stream, formatContext_);
433432

434433
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
435434
// addStream() which is supposed to be generic
@@ -441,7 +440,8 @@ void SingleStreamDecoder::addStream(
441440

442441
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
443442
TORCH_CHECK(codecContext != nullptr);
444-
streamInfo.codecContext.reset(codecContext);
443+
streamInfo.codecContext = SharedAVCodecContext(
444+
codecContext, [](AVCodecContext* ctx) { avcodec_free_context(&ctx); });
445445

446446
int retVal = avcodec_parameters_to_context(
447447
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
@@ -453,18 +453,19 @@ void SingleStreamDecoder::addStream(
453453
// Note that we must make sure to register the harware device context
454454
// with the codec context before calling avcodec_open2(). Otherwise, decoding
455455
// will happen on the CPU and not the hardware device.
456-
deviceInterface_->registerHardwareDeviceWithCodec(codecContext);
456+
deviceInterface_->registerHardwareDeviceWithCodec(
457+
streamInfo.codecContext.get());
457458
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
458459
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
459460

460-
codecContext->time_base = streamInfo.stream->time_base;
461+
streamInfo.codecContext->time_base = streamInfo.stream->time_base;
461462

462-
// Set the codec context on the device interface for default FFmpeg
463-
// implementations
464-
deviceInterface_->setCodecContext(codecContext);
463+
// Initialize the device interface with the codec context
464+
deviceInterface_->initialize(
465+
streamInfo.stream, formatContext_, streamInfo.codecContext);
465466

466467
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
467-
std::string(avcodec_get_name(codecContext->codec_id));
468+
std::string(avcodec_get_name(streamInfo.codecContext->codec_id));
468469

469470
// We will only need packets from the active stream, so we tell FFmpeg to
470471
// discard packets from the other streams. Note that av_read_frame() may still
@@ -1153,8 +1154,6 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
11531154
getFFMPEGErrorStringFromErrorCode(status));
11541155

11551156
decodeStats_.numFlushes++;
1156-
avcodec_flush_buffers(streamInfo.codecContext.get());
1157-
11581157
deviceInterface_->flush();
11591158
}
11601159

0 commit comments

Comments
 (0)