Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/torchcodec/_core/BetaCudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ class BetaCudaDeviceInterface : public DeviceInterface {
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

bool canDecodePacketDirectly() const override {
return true;
}

int sendPacket(ReferenceAVPacket& packet) override;
int sendEOFPacket() override;
int receiveFrame(UniqueAVFrame& avFrame) override;
Expand Down
50 changes: 27 additions & 23 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,42 +80,45 @@ class DeviceInterface {
// Extension points for custom decoding paths
// ------------------------------------------

// Override to return true if this device interface can decode packets
// directly. This means that the following two member functions can both
// be called:
//
// 1. sendPacket()
// 2. receiveFrame()
virtual bool canDecodePacketDirectly() const {
return false;
// Set the codec context for default FFmpeg decoding operations
// This must be called during initialization before using
// sendPacket/receiveFrame
virtual void setCodecContext(AVCodecContext* codecContext) {
codecContext_ = codecContext;
}

// Moral equivalent of avcodec_send_packet()
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
// other AVERROR on failure
virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) {
TORCH_CHECK(
false,
"Send/receive packet decoding not implemented for this device interface");
return AVERROR(ENOSYS);
// Default implementation uses FFmpeg directly
virtual int sendPacket(ReferenceAVPacket& avPacket) {
if (!codecContext_) {
TORCH_CHECK(
false, "Codec context not available for default packet sending");
return AVERROR(EINVAL);
}
return avcodec_send_packet(codecContext_, avPacket.get());
}

// Send an EOF packet to flush the decoder
// Returns AVSUCCESS on success, or other AVERROR on failure
// Default implementation uses FFmpeg directly
virtual int sendEOFPacket() {
TORCH_CHECK(
false, "Send EOF packet not implemented for this device interface");
return AVERROR(ENOSYS);
if (!codecContext_) {
TORCH_CHECK(false, "Codec context not available for EOF packet sending");
return AVERROR(EINVAL);
}
return avcodec_send_packet(codecContext_, nullptr);
}

// Moral equivalent of avcodec_receive_frame()
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
// AVERROR_EOF if end of stream, or other AVERROR on failure
virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
TORCH_CHECK(
false,
"Send/receive packet decoding not implemented for this device interface");
return AVERROR(ENOSYS);
// Default implementation uses FFmpeg directly
virtual int receiveFrame(UniqueAVFrame& avFrame) {
if (!codecContext_) {
TORCH_CHECK(false, "Codec context not available for frame receiving");
return AVERROR(EINVAL);
}
return avcodec_receive_frame(codecContext_, avFrame.get());
}

// Flush remaining frames from decoder
Expand All @@ -126,6 +129,7 @@ class DeviceInterface {

protected:
torch::Device device_;
AVCodecContext* codecContext_ = nullptr; // Non-owning pointer
};

using CreateDeviceInterfaceFn =
Expand Down
34 changes: 10 additions & 24 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,10 @@ void SingleStreamDecoder::addStream(

codecContext->time_base = streamInfo.stream->time_base;

// Set the codec context on the device interface for default FFmpeg
// implementations
deviceInterface_->setCodecContext(codecContext);

containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
std::string(avcodec_get_name(codecContext->codec_id));

Expand Down Expand Up @@ -1169,24 +1173,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
cursorWasJustSet_ = false;
}

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
UniqueAVFrame avFrame(av_frame_alloc());
AutoAVPacket autoAVPacket;
int status = AVSUCCESS;
bool reachedEOF = false;

// TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on
// if/else blocks to dispatch to the interface or to FFmpeg, consider *always*
// dispatching to the interface. The default implementation of the interface's
// receiveFrame and sendPacket could just be calling avcodec_receive_frame and
// avcodec_send_packet. This would make the decoding loop even more generic.
// The default implementation uses avcodec_receive_frame and
// avcodec_send_packet, while specialized interfaces can override for
// hardware-specific optimizations.
while (true) {
if (deviceInterface_->canDecodePacketDirectly()) {
status = deviceInterface_->receiveFrame(avFrame);
} else {
status =
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
}
status = deviceInterface_->receiveFrame(avFrame);

if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
// Non-retriable error
Expand Down Expand Up @@ -1222,13 +1218,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(

if (status == AVERROR_EOF) {
// End of file reached. We must drain the decoder
if (deviceInterface_->canDecodePacketDirectly()) {
status = deviceInterface_->sendEOFPacket();
} else {
status = avcodec_send_packet(
streamInfo.codecContext.get(),
/*avpkt=*/nullptr);
}
status = deviceInterface_->sendEOFPacket();
TORCH_CHECK(
status >= AVSUCCESS,
"Could not flush decoder: ",
Expand All @@ -1253,11 +1243,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(

// We got a valid packet. Send it to the decoder, and we'll receive it in
// the next iteration.
if (deviceInterface_->canDecodePacketDirectly()) {
status = deviceInterface_->sendPacket(packet);
} else {
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
}
status = deviceInterface_->sendPacket(packet);
TORCH_CHECK(
status >= AVSUCCESS,
"Could not push packet to decoder: ",
Expand Down
Loading