Skip to content

Commit ca1f538

Browse files
committed
changes
1 parent 408b373 commit ca1f538

File tree

12 files changed

+179
-31
lines changed

12 files changed

+179
-31
lines changed

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ class CpuDeviceInterface : public DeviceInterface {
1818

1919
virtual ~CpuDeviceInterface() {}
2020

21-
std::optional<const AVCodec*> findCodec(
22-
[[maybe_unused]] const AVCodecID& codecId) override {
23-
return std::nullopt;
24-
}
25-
2621
virtual void initialize(
2722
const AVStream* avStream,
2823
const UniqueDecodingAVFormatContext& avFormatCtx,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,40 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
329329
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
330330
}
331331

332+
namespace {
333+
// Helper function to check if a codec supports CUDA hardware acceleration
334+
bool codecSupportsCudaHardware(const AVCodec* codec) {
335+
const AVCodecHWConfig* config = nullptr;
336+
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) {
337+
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
338+
return true;
339+
}
340+
}
341+
return false;
342+
}
343+
} // namespace
344+
332345
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
333346
// we have to do this because of an FFmpeg bug where hardware decoding is not
334347
// appropriately set, so we just go off and find the matching codec for the CUDA
335348
// device
336-
std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
349+
350+
std::optional<const AVCodec*> CudaDeviceInterface::findEncoder(
351+
const AVCodecID& codecId) {
352+
void* i = nullptr;
353+
const AVCodec* codec = nullptr;
354+
while ((codec = av_codec_iterate(&i)) != nullptr) {
355+
if (codec->id != codecId || !av_codec_is_encoder(codec)) {
356+
continue;
357+
}
358+
if (codecSupportsCudaHardware(codec)) {
359+
return codec;
360+
}
361+
}
362+
return std::nullopt;
363+
}
364+
365+
std::optional<const AVCodec*> CudaDeviceInterface::findDecoder(
337366
const AVCodecID& codecId) {
338367
void* i = nullptr;
339368
const AVCodec* codec = nullptr;
@@ -342,12 +371,8 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
342371
continue;
343372
}
344373

345-
const AVCodecHWConfig* config = nullptr;
346-
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
347-
++j) {
348-
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
349-
return codec;
350-
}
374+
if (codecSupportsCudaHardware(codec)) {
375+
return codec;
351376
}
352377
}
353378

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class CudaDeviceInterface : public DeviceInterface {
1818

1919
virtual ~CudaDeviceInterface();
2020

21-
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
21+
std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId) override;
22+
std::optional<const AVCodec*> findDecoder(const AVCodecID& codecId) override;
2223

2324
void initialize(
2425
const AVStream* avStream,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ class DeviceInterface {
4646
return device_;
4747
};
4848

49-
virtual std::optional<const AVCodec*> findCodec(
49+
virtual std::optional<const AVCodec*> findEncoder(
50+
[[maybe_unused]] const AVCodecID& codecId) {
51+
return std::nullopt;
52+
};
53+
54+
virtual std::optional<const AVCodec*> findDecoder(
5055
[[maybe_unused]] const AVCodecID& codecId) {
5156
return std::nullopt;
5257
};

src/torchcodec/_core/Encoder.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,12 @@ VideoEncoder::VideoEncoder(
724724

725725
void VideoEncoder::initializeEncoder(
726726
const VideoStreamOptions& videoStreamOptions) {
727+
deviceInterface_ = createDeviceInterface(
728+
videoStreamOptions.device, videoStreamOptions.deviceVariant);
729+
TORCH_CHECK(
730+
deviceInterface_ != nullptr,
731+
"Failed to create device interface. This should never happen, please report.");
732+
727733
const AVCodec* avCodec = nullptr;
728734
// If codec arg is provided, find codec using logic similar to FFmpeg:
729735
// https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835
@@ -749,7 +755,13 @@ void VideoEncoder::initializeEncoder(
749755
avFormatContext_->oformat != nullptr,
750756
"Output format is null, unable to find default codec.");
751757
avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
758+
// TODO: merge above logic w this logic
759+
// Try to find a hardware-accelerated encoder if not using CPU
760+
if (videoStreamOptions.device.type() != torch::kCPU) {
761+
avCodec = deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec).value_or(avCodec);
752762
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
763+
}
764+
753765
}
754766

755767
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
@@ -820,6 +832,11 @@ void VideoEncoder::initializeEncoder(
820832
videoStreamOptions.preset.value().c_str(),
821833
0);
822834
}
835+
836+
// Register the hardware device context with the codec
837+
// context before calling avcodec_open2().
838+
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
839+
823840
int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions);
824841
av_dict_free(&avCodecOptions);
825842

src/torchcodec/_core/Encoder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <map>
44
#include <string>
55
#include "AVIOContextHolder.h"
6+
#include "DeviceInterface.h"
67
#include "FFMPEGCommon.h"
78
#include "StreamOptions.h"
89

@@ -183,6 +184,7 @@ class VideoEncoder {
183184
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
184185

185186
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
187+
std::unique_ptr<DeviceInterface> deviceInterface_;
186188

187189
bool encodeWasCalled_ = false;
188190
AVDictionary* avFormatOptions_ = nullptr;

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() {
4040

4141
AVCodecOnlyUseForCallingAVFindBestStream
4242
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
43-
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)
43+
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3
4444
return const_cast<AVCodec*>(codec);
4545
#else
4646
return codec;

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ void SingleStreamDecoder::addStream(
462462
// addStream() which is supposed to be generic
463463
if (mediaType == AVMEDIA_TYPE_VIDEO) {
464464
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
465-
deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
465+
deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id)
466466
.value_or(avCodec));
467467
}
468468

src/torchcodec/_core/custom_ops.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -641,6 +641,7 @@ void encode_video_to_file(
641641
const at::Tensor& frames,
642642
int64_t frame_rate,
643643
std::string_view file_name,
644+
std::string_view device = "cpu",
644645
std::optional<std::string_view> codec = std::nullopt,
645646
std::optional<std::string_view> pixel_format = std::nullopt,
646647
std::optional<double> crf = std::nullopt,
@@ -650,6 +651,8 @@ void encode_video_to_file(
650651
videoStreamOptions.codec = std::move(codec);
651652
videoStreamOptions.pixelFormat = std::move(pixel_format);
652653
videoStreamOptions.crf = crf;
654+
655+
videoStreamOptions.device = torch::Device(std::string(device));
653656
videoStreamOptions.preset = preset;
654657

655658
if (extra_options.has_value()) {
@@ -669,6 +672,7 @@ at::Tensor encode_video_to_tensor(
669672
const at::Tensor& frames,
670673
int64_t frame_rate,
671674
std::string_view format,
675+
std::string_view device = "cpu",
672676
std::optional<std::string_view> codec = std::nullopt,
673677
std::optional<std::string_view> pixel_format = std::nullopt,
674678
std::optional<double> crf = std::nullopt,
@@ -679,6 +683,8 @@ at::Tensor encode_video_to_tensor(
679683
videoStreamOptions.codec = std::move(codec);
680684
videoStreamOptions.pixelFormat = std::move(pixel_format);
681685
videoStreamOptions.crf = crf;
686+
687+
videoStreamOptions.device = torch::Device(std::string(device));
682688
videoStreamOptions.preset = preset;
683689

684690
if (extra_options.has_value()) {
@@ -700,6 +706,7 @@ void _encode_video_to_file_like(
700706
int64_t frame_rate,
701707
std::string_view format,
702708
int64_t file_like_context,
709+
std::string_view device = "cpu",
703710
std::optional<std::string_view> codec = std::nullopt,
704711
std::optional<std::string_view> pixel_format = std::nullopt,
705712
std::optional<double> crf = std::nullopt,
@@ -715,6 +722,7 @@ void _encode_video_to_file_like(
715722
videoStreamOptions.codec = std::move(codec);
716723
videoStreamOptions.pixelFormat = std::move(pixel_format);
717724
videoStreamOptions.crf = crf;
725+
videoStreamOptions.device = torch::Device(std::string(device));
718726
videoStreamOptions.preset = preset;
719727

720728
if (extra_options.has_value()) {

src/torchcodec/_core/ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def encode_video_to_file_like(
213213
frame_rate: int,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216+
device: Optional[str] = "cpu",
216217
codec: Optional[str] = None,
217218
pixel_format: Optional[str] = None,
218219
crf: Optional[Union[int, float]] = None,
@@ -226,6 +227,7 @@ def encode_video_to_file_like(
226227
frame_rate: Frame rate in frames per second
227228
format: Video format (e.g., "mp4", "mov", "mkv")
228229
file_like: File-like object that supports write() and seek() methods
230+
device: Device to use for encoding (default: "cpu")
229231
codec: Optional codec name (e.g., "libx264", "h264")
230232
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
231233
crf: Optional constant rate factor for encoding quality
@@ -239,6 +241,7 @@ def encode_video_to_file_like(
239241
frame_rate,
240242
format,
241243
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
244+
device,
242245
codec,
243246
pixel_format,
244247
crf,
@@ -331,11 +334,12 @@ def encode_video_to_file_abstract(
331334
frames: torch.Tensor,
332335
frame_rate: int,
333336
filename: str,
337+
device: str = "cpu",
334338
codec: Optional[str] = None,
335339
pixel_format: Optional[str] = None,
336340
preset: Optional[str] = None,
337341
crf: Optional[Union[int, float]] = None,
338-
extra_options: Optional[list[str]] = None,
342+
extra_options: Optional[list[str]] = None = None,
339343
) -> None:
340344
return
341345

@@ -345,11 +349,12 @@ def encode_video_to_tensor_abstract(
345349
frames: torch.Tensor,
346350
frame_rate: int,
347351
format: str,
352+
device: str = "cpu",
348353
codec: Optional[str] = None,
349354
pixel_format: Optional[str] = None,
350355
preset: Optional[str] = None,
351356
crf: Optional[Union[int, float]] = None,
352-
extra_options: Optional[list[str]] = None,
357+
extra_options: Optional[list[str]] = None = None,
353358
) -> torch.Tensor:
354359
return torch.empty([], dtype=torch.long)
355360

@@ -360,6 +365,7 @@ def _encode_video_to_file_like_abstract(
360365
frame_rate: int,
361366
format: str,
362367
file_like_context: int,
368+
device: str = "cpu",
363369
codec: Optional[str] = None,
364370
pixel_format: Optional[str] = None,
365371
preset: Optional[str] = None,

0 commit comments

Comments
 (0)