Skip to content

Commit 9efb767

Browse files
committed
Stragglers
1 parent 4b9f4c9 commit 9efb767

File tree

5 files changed

+21
-28
lines changed

5 files changed

+21
-28
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,20 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
156156
}
157157
}
158158

159-
void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
159+
void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
160160
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
161161
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
162162

163163
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
164164
timeBase_ = avStream->time_base;
165165

166+
auto cudaDevice = torch::Device(torch::kCUDA);
167+
defaultCudaInterface_ =
168+
std::unique_ptr<DeviceInterface>(createDeviceInterface(cudaDevice));
169+
AVCodecContext dummyCodecContext = {};
170+
defaultCudaInterface_->initialize(avStream);
171+
defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext);
172+
166173
const AVCodecParameters* codecpar = avStream->codecpar;
167174
TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null");
168175

@@ -523,8 +530,6 @@ void BetaCudaDeviceInterface::flush() {
523530
}
524531

525532
void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
526-
const VideoStreamOptions& videoStreamOptions,
527-
const AVRational& timeBase,
528533
UniqueAVFrame& avFrame,
529534
FrameOutput& frameOutput,
530535
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -535,20 +540,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
535540
// TODONVDEC P1: we use the 'default' cuda device interface for color
536541
// conversion. That's a temporary hack to make things work. we should abstract
537542
// the color conversion stuff separately.
538-
if (!defaultCudaInterface_) {
539-
auto cudaDevice = torch::Device(torch::kCUDA);
540-
defaultCudaInterface_ =
541-
std::unique_ptr<DeviceInterface>(createDeviceInterface(cudaDevice));
542-
AVCodecContext dummyCodecContext = {};
543-
defaultCudaInterface_->initializeContext(&dummyCodecContext);
544-
}
545-
546543
defaultCudaInterface_->convertAVFrameToFrameOutput(
547-
videoStreamOptions,
548-
timeBase,
549-
avFrame,
550-
frameOutput,
551-
preAllocatedOutputTensor);
544+
avFrame, frameOutput, preAllocatedOutputTensor);
552545
}
553546

554547
BetaCudaDeviceInterface::FrameBuffer::Slot*

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
3737
explicit BetaCudaDeviceInterface(const torch::Device& device);
3838
virtual ~BetaCudaDeviceInterface();
3939

40-
void initializeInterface(AVStream* stream) override;
40+
void initialize(const AVStream* avStream) override;
4141

4242
void convertAVFrameToFrameOutput(
43-
const VideoStreamOptions& videoStreamOptions,
44-
const AVRational& timeBase,
4543
UniqueAVFrame& avFrame,
4644
FrameOutput& frameOutput,
4745
std::optional<torch::Tensor> preAllocatedOutputTensor =

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,9 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4646
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
4747
}
4848

49-
void CpuDeviceInterface::initialize(
50-
[[maybe_unused]] AVCodecContext* codecContext,
51-
const AVRational& timeBase) {
52-
timeBase_ = timeBase;
49+
void CpuDeviceInterface::initialize(const AVStream* avStream) {
50+
TORCH_CHECK(avStream != nullptr, "avStream is null");
51+
timeBase_ = avStream->time_base;
5352
}
5453

5554
void CpuDeviceInterface::initializeVideo(

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ class CpuDeviceInterface : public DeviceInterface {
2323
return std::nullopt;
2424
}
2525

26-
virtual void initialize(
27-
[[maybe_unused]] AVCodecContext* codecContext,
28-
const AVRational& timeBase) override;
26+
virtual void initialize(const AVStream* avStream) override;
2927

3028
virtual void initializeVideo(
3129
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ class CudaDeviceInterface : public DeviceInterface {
2020

2121
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
2222

23-
void initialize(AVCodecContext* codecContext, const AVRational& timeBase)
24-
override;
23+
void initialize(const AVStream* avStream) override;
2524

2625
void initializeVideo(
2726
const VideoStreamOptions& videoStreamOptions,
@@ -30,6 +29,8 @@ class CudaDeviceInterface : public DeviceInterface {
3029
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims)
3130
override;
3231

32+
void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override;
33+
3334
void convertAVFrameToFrameOutput(
3435
UniqueAVFrame& avFrame,
3536
FrameOutput& frameOutput,
@@ -42,6 +43,10 @@ class CudaDeviceInterface : public DeviceInterface {
4243
// does this using filtergraph.
4344
UniqueAVFrame maybeConvertAVFrameToNV12OrRGB24(UniqueAVFrame& avFrame);
4445

46+
// We sometimes encounter frames that cannot be decoded on the CUDA device.
47+
// Rather than erroring out, we decode them on the CPU.
48+
std::unique_ptr<DeviceInterface> cpuInterface_;
49+
4550
VideoStreamOptions videoStreamOptions_;
4651
AVRational timeBase_;
4752

0 commit comments

Comments
 (0)