@@ -163,13 +163,20 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
163163 }
164164}
165165
166- void BetaCudaDeviceInterface::initializeInterface ( AVStream* avStream) {
166+ void BetaCudaDeviceInterface::initialize ( const AVStream* avStream) {
167167 torch::Tensor dummyTensorForCudaInitialization = torch::empty (
168168 {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
169169
170170 TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
171171 timeBase_ = avStream->time_base ;
172172
173+ auto cudaDevice = torch::Device (torch::kCUDA );
174+ defaultCudaInterface_ =
175+ std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
176+ AVCodecContext dummyCodecContext = {};
177+ defaultCudaInterface_->initialize (avStream);
178+ defaultCudaInterface_->registerHardwareDeviceWithCodec (&dummyCodecContext);
179+
173180 const AVCodecParameters* codecpar = avStream->codecpar ;
174181 TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
175182
@@ -489,8 +496,6 @@ void BetaCudaDeviceInterface::flush() {
489496}
490497
491498void BetaCudaDeviceInterface::convertAVFrameToFrameOutput (
492- const VideoStreamOptions& videoStreamOptions,
493- const AVRational& timeBase,
494499 UniqueAVFrame& avFrame,
495500 FrameOutput& frameOutput,
496501 std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -501,20 +506,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
501506 // TODONVDEC P1: we use the 'default' cuda device interface for color
502507 // conversion. That's a temporary hack to make things work. we should abstract
503508 // the color conversion stuff separately.
504- if (!defaultCudaInterface_) {
505- auto cudaDevice = torch::Device (torch::kCUDA );
506- defaultCudaInterface_ =
507- std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
508- AVCodecContext dummyCodecContext = {};
509- defaultCudaInterface_->initializeContext (&dummyCodecContext);
510- }
511-
512509 defaultCudaInterface_->convertAVFrameToFrameOutput (
513- videoStreamOptions,
514- timeBase,
515- avFrame,
516- frameOutput,
517- preAllocatedOutputTensor);
510+ avFrame, frameOutput, preAllocatedOutputTensor);
518511}
519512
520513} // namespace facebook::torchcodec
0 commit comments