Skip to content

Commit 718c268

Browse files
author
Molly Xu
committed
Refactor receiveFrame and sendPacket logic to dispatch directly to interface
1 parent e5b2eef commit 718c268

File tree

3 files changed

+37
-50
lines changed

3 files changed

+37
-50
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,6 @@ class BetaCudaDeviceInterface : public DeviceInterface {
4848
std::optional<torch::Tensor> preAllocatedOutputTensor =
4949
std::nullopt) override;
5050

51-
bool canDecodePacketDirectly() const override {
52-
return true;
53-
}
5451

5552
int sendPacket(ReferenceAVPacket& packet) override;
5653
int sendEOFPacket() override;

src/torchcodec/_core/DeviceInterface.h

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -80,42 +80,45 @@ class DeviceInterface {
8080
// Extension points for custom decoding paths
8181
// ------------------------------------------
8282

83-
// Override to return true if this device interface can decode packets
84-
// directly. This means that the following two member functions can both
85-
// be called:
86-
//
87-
// 1. sendPacket()
88-
// 2. receiveFrame()
89-
virtual bool canDecodePacketDirectly() const {
90-
return false;
83+
// Set the codec context for default FFmpeg decoding operations
84+
// This must be called during initialization before using
85+
// sendPacket/receiveFrame
86+
virtual void setCodecContext(AVCodecContext* codecContext) {
87+
codecContext_ = codecContext;
9188
}
9289

93-
// Moral equivalent of avcodec_send_packet()
9490
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
9591
// other AVERROR on failure
96-
virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) {
97-
TORCH_CHECK(
98-
false,
99-
"Send/receive packet decoding not implemented for this device interface");
100-
return AVERROR(ENOSYS);
92+
// Default implementation uses FFmpeg directly
93+
virtual int sendPacket(ReferenceAVPacket& avPacket) {
94+
if (!codecContext_) {
95+
TORCH_CHECK(
96+
false, "Codec context not available for default packet sending");
97+
return AVERROR(EINVAL);
98+
}
99+
return avcodec_send_packet(codecContext_, avPacket.get());
101100
}
102101

103102
// Send an EOF packet to flush the decoder
104103
// Returns AVSUCCESS on success, or other AVERROR on failure
104+
// Default implementation uses FFmpeg directly
105105
virtual int sendEOFPacket() {
106-
TORCH_CHECK(
107-
false, "Send EOF packet not implemented for this device interface");
108-
return AVERROR(ENOSYS);
106+
if (!codecContext_) {
107+
TORCH_CHECK(false, "Codec context not available for EOF packet sending");
108+
return AVERROR(EINVAL);
109+
}
110+
return avcodec_send_packet(codecContext_, nullptr);
109111
}
110112

111-
// Moral equivalent of avcodec_receive_frame()
112113
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
113114
// AVERROR_EOF if end of stream, or other AVERROR on failure
114-
virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
115-
TORCH_CHECK(
116-
false,
117-
"Send/receive packet decoding not implemented for this device interface");
118-
return AVERROR(ENOSYS);
115+
// Default implementation uses FFmpeg directly
116+
virtual int receiveFrame(UniqueAVFrame& avFrame) {
117+
if (!codecContext_) {
118+
TORCH_CHECK(false, "Codec context not available for frame receiving");
119+
return AVERROR(EINVAL);
120+
}
121+
return avcodec_receive_frame(codecContext_, avFrame.get());
119122
}
120123

121124
// Flush remaining frames from decoder
@@ -126,6 +129,7 @@ class DeviceInterface {
126129

127130
protected:
128131
torch::Device device_;
132+
AVCodecContext* codecContext_ = nullptr; // Non-owning pointer
129133
};
130134

131135
using CreateDeviceInterfaceFn =

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,10 @@ void SingleStreamDecoder::addStream(
459459

460460
codecContext->time_base = streamInfo.stream->time_base;
461461

462+
// Set the codec context on the device interface for default FFmpeg
463+
// implementations
464+
deviceInterface_->setCodecContext(codecContext);
465+
462466
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
463467
std::string(avcodec_get_name(codecContext->codec_id));
464468

@@ -1169,24 +1173,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11691173
cursorWasJustSet_ = false;
11701174
}
11711175

1172-
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
11731176
UniqueAVFrame avFrame(av_frame_alloc());
11741177
AutoAVPacket autoAVPacket;
11751178
int status = AVSUCCESS;
11761179
bool reachedEOF = false;
11771180

1178-
// TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on
1179-
// if/else blocks to dispatch to the interface or to FFmpeg, consider *always*
1180-
// dispatching to the interface. The default implementation of the interface's
1181-
// receiveFrame and sendPacket could just be calling avcodec_receive_frame and
1182-
// avcodec_send_packet. This would make the decoding loop even more generic.
1181+
// The default implementation uses avcodec_receive_frame and
1182+
// avcodec_send_packet, while specialized interfaces can override for
1183+
// hardware-specific optimizations.
11831184
while (true) {
1184-
if (deviceInterface_->canDecodePacketDirectly()) {
1185-
status = deviceInterface_->receiveFrame(avFrame);
1186-
} else {
1187-
status =
1188-
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
1189-
}
1185+
status = deviceInterface_->receiveFrame(avFrame);
11901186

11911187
if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
11921188
// Non-retriable error
@@ -1222,13 +1218,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12221218

12231219
if (status == AVERROR_EOF) {
12241220
// End of file reached. We must drain the decoder
1225-
if (deviceInterface_->canDecodePacketDirectly()) {
1226-
status = deviceInterface_->sendEOFPacket();
1227-
} else {
1228-
status = avcodec_send_packet(
1229-
streamInfo.codecContext.get(),
1230-
/*avpkt=*/nullptr);
1231-
}
1221+
status = deviceInterface_->sendEOFPacket();
12321222
TORCH_CHECK(
12331223
status >= AVSUCCESS,
12341224
"Could not flush decoder: ",
@@ -1253,11 +1243,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12531243

12541244
// We got a valid packet. Send it to the decoder, and we'll receive it in
12551245
// the next iteration.
1256-
if (deviceInterface_->canDecodePacketDirectly()) {
1257-
status = deviceInterface_->sendPacket(packet);
1258-
} else {
1259-
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
1260-
}
1246+
status = deviceInterface_->sendPacket(packet);
12611247
TORCH_CHECK(
12621248
status >= AVSUCCESS,
12631249
"Could not push packet to decoder: ",

0 commit comments

Comments
 (0)