
Commit d2e9bde

Refactor device interface, again.
1 parent ee3b9b7 commit d2e9bde

6 files changed: +55 -32 lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 5 additions & 2 deletions
@@ -48,12 +48,15 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 
 void CpuDeviceInterface::initialize(
     [[maybe_unused]] AVCodecContext* codecContext,
+    const AVRational& timeBase) {
+  timeBase_ = timeBase;
+}
+
+void CpuDeviceInterface::initializeVideo(
     const VideoStreamOptions& videoStreamOptions,
     const std::vector<std::unique_ptr<Transform>>& transforms,
-    const AVRational& timeBase,
     const std::optional<FrameDims>& resizedOutputDims) {
   videoStreamOptions_ = videoStreamOptions;
-  timeBase_ = timeBase;
   resizedOutputDims_ = resizedOutputDims;
 
   // We can only use swscale when we have a single resize transform. Note that

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 11 additions & 1 deletion
@@ -25,9 +25,11 @@ class CpuDeviceInterface : public DeviceInterface {
 
   virtual void initialize(
       [[maybe_unused]] AVCodecContext* codecContext,
+      const AVRational& timeBase) override;
+
+  virtual void initializeVideo(
       const VideoStreamOptions& videoStreamOptions,
       const std::vector<std::unique_ptr<Transform>>& transforms,
-      const AVRational& timeBase,
       const std::optional<FrameDims>& resizedOutputDims) override;
 
   void convertAVFrameToFrameOutput(
@@ -73,6 +75,14 @@ class CpuDeviceInterface : public DeviceInterface {
 
   VideoStreamOptions videoStreamOptions_;
   AVRational timeBase_;
+
+  // If the resized output dimensions are present, then we always use those as
+  // the output frame's dimensions. If they are not present, then we use the
+  // dimensions of the raw decoded frame. Note that we do not know the
+  // dimensions of the raw decoded frame until very late; we learn it in
+  // convertAVFrameToFrameOutput(). Deciding the final output frame's actual
+  // dimensions late allows us to handle video streams with variable
+  // resolutions.
   std::optional<FrameDims> resizedOutputDims_;
 
   // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
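The new comment on resizedOutputDims_ describes a decision that is deferred until conversion time. Below is a minimal, self-contained sketch of that decision; the FrameDims stand-in and the chooseOutputDims name are illustrative only, not the actual TorchCodec code.

#include <optional>

// Toy stand-in for torchcodec's FrameDims; the real struct may differ.
struct FrameDims {
  int height;
  int width;
};

// Pick the output dimensions as late as possible: use the pre-computed
// resized dimensions when present, otherwise fall back to the dimensions of
// the raw decoded frame. Falling back per-frame is what lets streams with
// variable resolutions work.
FrameDims chooseOutputDims(
    const std::optional<FrameDims>& resizedOutputDims,
    int decodedHeight,
    int decodedWidth) {
  return resizedOutputDims.value_or(FrameDims{decodedHeight, decodedWidth});
}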

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 13 additions & 11 deletions
@@ -195,18 +195,20 @@ CudaDeviceInterface::~CudaDeviceInterface() {
 
 void CudaDeviceInterface::initialize(
     AVCodecContext* codecContext,
-    const VideoStreamOptions& videoStreamOptions,
-    [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
-    const AVRational& timeBase,
-    [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
+    const AVRational& timeBase) {
   TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
   TORCH_CHECK(codecContext != nullptr, "codecContext is null");
-
   codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
-  videoStreamOptions_ = videoStreamOptions;
   timeBase_ = timeBase;
 }
 
+void CudaDeviceInterface::initializeVideo(
+    const VideoStreamOptions& videoStreamOptions,
+    [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>& transforms,
+    [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {
+  videoStreamOptions_ = videoStreamOptions;
+}
+
 UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
     UniqueAVFrame& avFrame) {
   // We need FFmpeg filters to handle those conversion cases which are not
@@ -220,13 +222,13 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12(
     return std::move(avFrame);
   }
 
+  auto hwFramesCtx =
+      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
   TORCH_CHECK(
-      avFrame->hw_frames_ctx != nullptr,
+      hwFramesCtx != nullptr,
       "The AVFrame does not have a hw_frames_ctx. "
       "That's unexpected, please report this to the TorchCodec repo.");
 
-  auto hwFramesCtx =
-      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
   AVPixelFormat actualFormat = hwFramesCtx->sw_format;
 
   // If the frame is already in NV12 format, we don't need to do anything.
@@ -355,10 +357,10 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
   TORCH_CHECK(
       cpuInterface != nullptr, "Failed to create CPU device interface");
   cpuInterface->initialize(
-      nullptr,
+      /*codecContext=*/nullptr, timeBase_);
+  cpuInterface->initializeVideo(
       VideoStreamOptions(),
       {},
-      timeBase_,
       /*resizedOutputDims=*/std::nullopt);
 
   cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
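The maybeConvertAVFrameToNV12() hunk above reads the software pixel format out of the frame's hardware frames context to decide whether any conversion is needed. A minimal sketch of that check in plain FFmpeg terms follows; the needsNV12Conversion helper is illustrative only, not a TorchCodec function.

extern "C" {
#include <libavutil/frame.h>
#include <libavutil/hwcontext.h>
#include <libavutil/pixfmt.h>
}

// Returns true when a hardware-decoded frame still needs conversion to NV12.
// The underlying (software) pixel format of a hardware frame lives in the
// AVHWFramesContext attached to the frame, not in avFrame->format.
bool needsNV12Conversion(const AVFrame* avFrame) {
  if (avFrame->hw_frames_ctx == nullptr) {
    return false;  // Not a hardware frame; nothing to convert here.
  }
  auto* hwFramesCtx =
      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
  return hwFramesCtx->sw_format != AV_PIX_FMT_NV12;
}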

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 4 additions & 3 deletions
@@ -20,12 +20,13 @@ class CudaDeviceInterface : public DeviceInterface {
 
   std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
 
-  void initialize(
-      AVCodecContext* codecContext,
+  void initialize(AVCodecContext* codecContext, const AVRational& timeBase)
+      override;
+
+  void initializeVideo(
       const VideoStreamOptions& videoStreamOptions,
       [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>&
           transforms,
-      const AVRational& timeBase,
       [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims)
       override;
 

src/torchcodec/_core/DeviceInterface.h

Lines changed: 10 additions & 6 deletions
@@ -30,14 +30,18 @@ class DeviceInterface {
 
   virtual std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) = 0;
 
-  // Initialize the hardware device that is specified in `device`. Some builds
-  // support CUDA and others only support CPU.
+  // Initialize the device with parameters generic to all kinds of decoding.
   virtual void initialize(
       AVCodecContext* codecContext,
-      const VideoStreamOptions& videoStreamOptions,
-      const std::vector<std::unique_ptr<Transform>>& transforms,
-      const AVRational& timeBase,
-      const std::optional<FrameDims>& resizedOutputDims) = 0;
+      const AVRational& timeBase) = 0;
+
+  // Initialize the device with parameters specific to video decoding. There is
+  // a default empty implementation.
+  virtual void initializeVideo(
+      [[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
+      [[maybe_unused]] const std::vector<std::unique_ptr<Transform>>&
+          transforms,
+      [[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {}
 
   virtual void convertAVFrameToFrameOutput(
       UniqueAVFrame& avFrame,
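The split in DeviceInterface.h keeps initialize() as a mandatory pure virtual while initializeVideo() is opt-in thanks to its default empty body. Below is a self-contained toy sketch of that pattern; the Toy* classes and simplified parameters are illustrative only, not the real TorchCodec types.

#include <iostream>

struct ToyDeviceInterface {
  virtual ~ToyDeviceInterface() = default;

  // Generic to all kinds of decoding: every interface must implement it.
  virtual void initialize(int timeBaseNum, int timeBaseDen) = 0;

  // Specific to video decoding: default empty implementation, so interfaces
  // and call sites that never decode video can simply ignore it.
  virtual void initializeVideo(int outputWidth, int outputHeight) {}
};

struct ToyCpuInterface : ToyDeviceInterface {
  void initialize(int timeBaseNum, int timeBaseDen) override {
    std::cout << "time base: " << timeBaseNum << "/" << timeBaseDen << "\n";
  }
  void initializeVideo(int outputWidth, int outputHeight) override {
    std::cout << "output: " << outputWidth << "x" << outputHeight << "\n";
  }
};

int main() {
  ToyCpuInterface cpu;
  cpu.initialize(1, 90000);       // always called, for any stream type
  cpu.initializeVideo(640, 480);  // only called when adding a video stream
  return 0;
}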

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 12 additions & 9 deletions
@@ -436,13 +436,20 @@ void SingleStreamDecoder::addStream(
   TORCH_CHECK(codecContext != nullptr);
   streamInfo.codecContext.reset(codecContext);
 
+  deviceInterface_->initialize(
+      streamInfo.codecContext.get(), streamInfo.timeBase);
+
   int retVal = avcodec_parameters_to_context(
       streamInfo.codecContext.get(), streamInfo.stream->codecpar);
   TORCH_CHECK_EQ(retVal, AVSUCCESS);
 
   streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
   streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;
 
+  // Note that we must make sure to call avcodec_open2() AFTER we initialize
+  // the device interface. Device initialization tells the codec context which
+  // device to use. If we initialize the device interface after avcodec_open2(),
+  // then all decoding may fall back to the CPU.
   retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
   TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
 
@@ -510,18 +517,14 @@ void SingleStreamDecoder::addVideoStream(
     if (transform->getOutputFrameDims().has_value()) {
       resizedOutputDims_ = transform->getOutputFrameDims().value();
     }
+
+    // Note that we are claiming ownership of the transform objects passed in to
+    // us.
     transforms_.push_back(std::unique_ptr<Transform>(transform));
   }
 
-  // We initialize the device context late because we want to know a lot of
-  // information that we can only know after resolving the codec, opening the
-  // stream and inspecting the metadata.
-  deviceInterface_->initialize(
-      streamInfo.codecContext.get(),
-      videoStreamOptions,
-      transforms_,
-      streamInfo.timeBase,
-      resizedOutputDims_);
+  deviceInterface_->initializeVideo(
+      videoStreamOptions, transforms_, resizedOutputDims_);
 }
 
 void SingleStreamDecoder::addAudioStream(
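The new comment in addStream() pins down an ordering constraint: the device has to be attached to the codec context before avcodec_open2(). A minimal sketch of that constraint in plain FFmpeg terms follows; openCodecWithCuda is an assumed illustrative helper, not TorchCodec code.

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/hwcontext.h>
}

// Open a codec context with CUDA acceleration. The key point is ordering:
// hw_device_ctx must be set BEFORE avcodec_open2(); otherwise the codec opens
// without a hardware device and decoding silently falls back to the CPU.
int openCodecWithCuda(AVCodecContext* codecContext, const AVCodec* codec) {
  AVBufferRef* hwDeviceCtx = nullptr;
  int err = av_hwdevice_ctx_create(
      &hwDeviceCtx, AV_HWDEVICE_TYPE_CUDA, nullptr, nullptr, 0);
  if (err < 0) {
    return err;
  }

  // Attach the device to the codec context, then open the codec.
  codecContext->hw_device_ctx = av_buffer_ref(hwDeviceCtx);
  av_buffer_unref(&hwDeviceCtx);  // codecContext now holds its own reference
  return avcodec_open2(codecContext, codec, nullptr);
}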
