Skip to content

Commit ef5fb6a

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into transform_benchmark
2 parents e559838 + c1798db commit ef5fb6a

File tree

12 files changed

+271
-98
lines changed

12 files changed

+271
-98
lines changed

.github/workflows/windows_wheel.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ jobs:
7171
# TODO: FFmpeg 5 on Windows segfaults in avcodec_open2() when passing
7272
# bad parameters.
7373
# See https://github.com/pytorch/torchcodec/pull/806
74-
# TODO: Support FFmpeg 8 on Windows
75-
ffmpeg-version-for-tests: ['4.4.2', '6.1.1', '7.0.1']
74+
ffmpeg-version-for-tests: ['4.4.2', '6.1.1', '7.0.1', '8.0']
7675
needs: build
7776
steps:
7877
- uses: actions/download-artifact@v4
@@ -83,7 +82,11 @@ jobs:
8382
uses: conda-incubator/setup-miniconda@v2
8483
with:
8584
auto-update-conda: true
86-
miniconda-version: "latest"
85+
# Using miniforge instead of miniconda ensures that the default
86+
# conda channel is conda-forge instead of main/default. This ensures
87+
# ABI consistency between dependencies:
88+
# https://conda-forge.org/docs/user/transitioning_from_defaults/
89+
miniforge-version: latest
8790
activate-environment: test
8891
python-version: ${{ matrix.python-version }}
8992
- name: Update pip

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,34 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
1818
int read(void* opaque, uint8_t* buf, int buf_size) {
1919
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
2020
TORCH_CHECK(
21-
tensorContext->current <= tensorContext->data.numel(),
22-
"Tried to read outside of the buffer: current=",
23-
tensorContext->current,
21+
tensorContext->current_pos <= tensorContext->data.numel(),
22+
"Tried to read outside of the buffer: current_pos=",
23+
tensorContext->current_pos,
2424
", size=",
2525
tensorContext->data.numel());
2626

2727
int64_t numBytesRead = std::min(
2828
static_cast<int64_t>(buf_size),
29-
tensorContext->data.numel() - tensorContext->current);
29+
tensorContext->data.numel() - tensorContext->current_pos);
3030

3131
TORCH_CHECK(
3232
numBytesRead >= 0,
3333
"Tried to read negative bytes: numBytesRead=",
3434
numBytesRead,
3535
", size=",
3636
tensorContext->data.numel(),
37-
", current=",
38-
tensorContext->current);
37+
", current_pos=",
38+
tensorContext->current_pos);
3939

4040
if (numBytesRead == 0) {
4141
return AVERROR_EOF;
4242
}
4343

4444
std::memcpy(
4545
buf,
46-
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
46+
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
4747
numBytesRead);
48-
tensorContext->current += numBytesRead;
48+
tensorContext->current_pos += numBytesRead;
4949
return numBytesRead;
5050
}
5151

@@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
5454
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
5555

5656
int64_t bufSize = static_cast<int64_t>(buf_size);
57-
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
57+
if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
5858
TORCH_CHECK(
5959
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
6060
"We tried to allocate an output encoded tensor larger than ",
@@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
6868
}
6969

7070
TORCH_CHECK(
71-
tensorContext->current + bufSize <= tensorContext->data.numel(),
71+
tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
7272
"Re-allocation of the output tensor didn't work. ",
7373
"This should not happen, please report on TorchCodec bug tracker");
7474

7575
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
76-
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
77-
tensorContext->current += bufSize;
76+
std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
77+
tensorContext->current_pos += bufSize;
78+
// Track the maximum position written so getOutputTensor's narrow() does not
79+
// truncate the file if final seek was backwards
80+
tensorContext->max_pos =
81+
std::max(tensorContext->current_pos, tensorContext->max_pos);
7882
return buf_size;
7983
}
8084

@@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
8892
ret = tensorContext->data.numel();
8993
break;
9094
case SEEK_SET:
91-
tensorContext->current = offset;
95+
tensorContext->current_pos = offset;
9296
ret = offset;
9397
break;
9498
default:
@@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
101105
} // namespace
102106

103107
AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
104-
: tensorContext_{data, 0} {
108+
: tensorContext_{data, 0, 0} {
105109
TORCH_CHECK(data.numel() > 0, "data must not be empty");
106110
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
107111
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
@@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
110114
}
111115

112116
AVIOToTensorContext::AVIOToTensorContext()
113-
: tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
117+
: tensorContext_{
118+
torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}),
119+
0,
120+
0} {
114121
createAVIOContext(
115122
nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
116123
}
117124

118125
torch::Tensor AVIOToTensorContext::getOutputTensor() {
119126
return tensorContext_.data.narrow(
120-
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
127+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
121128
}
122129

123130
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOTensorContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace detail {
1515

1616
struct TensorContext {
1717
torch::Tensor data;
18-
int64_t current;
18+
int64_t current_pos;
19+
int64_t max_pos;
1920
};
2021

2122
} // namespace detail

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,10 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
6060

6161
// Create hardware device context
6262
c10::cuda::CUDAGuard deviceGuard(device);
63-
// Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
64-
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
65-
// So we ensure the deviceIndex is not negative.
6663
// We set the device because we may be called from a different thread than
6764
// the one that initialized the cuda context.
68-
cudaSetDevice(deviceIndex);
65+
TORCH_CHECK(
66+
cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device");
6967
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
7068
std::string deviceOrdinal = std::to_string(deviceIndex);
7169

src/torchcodec/_core/Encoder.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
#include "src/torchcodec/_core/Encoder.h"
55
#include "torch/types.h"
66

7-
extern "C" {
8-
#include <libavutil/pixdesc.h>
9-
}
10-
117
namespace facebook::torchcodec {
128

139
namespace {
@@ -542,10 +538,17 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
542538
} // namespace
543539

544540
VideoEncoder::~VideoEncoder() {
541+
// TODO-VideoEncoder: Unify destructor with ~AudioEncoder()
545542
if (avFormatContext_ && avFormatContext_->pb) {
546-
avio_flush(avFormatContext_->pb);
547-
avio_close(avFormatContext_->pb);
548-
avFormatContext_->pb = nullptr;
543+
if (avFormatContext_->pb->error == 0) {
544+
avio_flush(avFormatContext_->pb);
545+
}
546+
if (!avioContextHolder_) {
547+
if (avFormatContext_->pb->error == 0) {
548+
avio_close(avFormatContext_->pb);
549+
}
550+
avFormatContext_->pb = nullptr;
551+
}
549552
}
550553
}
551554

@@ -581,6 +584,36 @@ VideoEncoder::VideoEncoder(
581584
initializeEncoder(videoStreamOptions);
582585
}
583586

587+
VideoEncoder::VideoEncoder(
588+
const torch::Tensor& frames,
589+
int frameRate,
590+
std::string_view formatName,
591+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
592+
const VideoStreamOptions& videoStreamOptions)
593+
: frames_(validateFrames(frames)),
594+
inFrameRate_(frameRate),
595+
avioContextHolder_(std::move(avioContextHolder)) {
596+
setFFmpegLogLevel();
597+
// Map mkv -> matroska when used as format name
598+
formatName = (formatName == "mkv") ? "matroska" : formatName;
599+
AVFormatContext* avFormatContext = nullptr;
600+
int status = avformat_alloc_output_context2(
601+
&avFormatContext, nullptr, formatName.data(), nullptr);
602+
603+
TORCH_CHECK(
604+
avFormatContext != nullptr,
605+
"Couldn't allocate AVFormatContext. ",
606+
"Check the desired format? Got format=",
607+
formatName,
608+
". ",
609+
getFFMPEGErrorStringFromErrorCode(status));
610+
avFormatContext_.reset(avFormatContext);
611+
612+
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
613+
614+
initializeEncoder(videoStreamOptions);
615+
}
616+
584617
void VideoEncoder::initializeEncoder(
585618
const VideoStreamOptions& videoStreamOptions) {
586619
const AVCodec* avCodec =
@@ -751,6 +784,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
751784
return avFrame;
752785
}
753786

787+
torch::Tensor VideoEncoder::encodeToTensor() {
788+
TORCH_CHECK(
789+
avioContextHolder_ != nullptr,
790+
"Cannot encode to tensor, avio tensor context doesn't exist.");
791+
encode();
792+
auto avioToTensorContext =
793+
dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
794+
TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
795+
return avioToTensorContext->getOutputTensor();
796+
}
797+
754798
void VideoEncoder::encodeFrame(
755799
AutoAVPacket& autoAVPacket,
756800
const UniqueAVFrame& avFrame) {

src/torchcodec/_core/Encoder.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,17 @@ class VideoEncoder {
141141
std::string_view fileName,
142142
const VideoStreamOptions& videoStreamOptions);
143143

144+
VideoEncoder(
145+
const torch::Tensor& frames,
146+
int frameRate,
147+
std::string_view formatName,
148+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
149+
const VideoStreamOptions& videoStreamOptions);
150+
144151
void encode();
145152

153+
torch::Tensor encodeToTensor();
154+
146155
private:
147156
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
148157
UniqueAVFrame convertTensorToAVFrame(
@@ -167,6 +176,8 @@ class VideoEncoder {
167176
int outHeight_ = -1;
168177
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
169178

179+
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
180+
170181
bool encodeWasCalled_ = false;
171182
};
172183

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_tensor,
2930
get_ffmpeg_library_versions,
3031
get_frame_at_index,
3132
get_frame_at_pts,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3232
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3333
m.def(
3434
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
35-
m.def(
36-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
3735
m.def(
3836
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3937
m.def(
4038
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
39+
m.def(
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
41+
m.def(
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
4143
m.def(
4244
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4345
m.def(
@@ -520,21 +522,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
520522
return makeOpsAudioFramesOutput(result);
521523
}
522524

523-
void encode_video_to_file(
524-
const at::Tensor& frames,
525-
int64_t frame_rate,
526-
std::string_view file_name,
527-
std::optional<int64_t> crf = std::nullopt) {
528-
VideoStreamOptions videoStreamOptions;
529-
videoStreamOptions.crf = crf;
530-
VideoEncoder(
531-
frames,
532-
validateInt64ToInt(frame_rate, "frame_rate"),
533-
file_name,
534-
videoStreamOptions)
535-
.encode();
536-
}
537-
538525
void encode_audio_to_file(
539526
const at::Tensor& samples,
540527
int64_t sample_rate,
@@ -609,6 +596,38 @@ void _encode_audio_to_file_like(
609596
encoder.encode();
610597
}
611598

599+
void encode_video_to_file(
600+
const at::Tensor& frames,
601+
int64_t frame_rate,
602+
std::string_view file_name,
603+
std::optional<int64_t> crf = std::nullopt) {
604+
VideoStreamOptions videoStreamOptions;
605+
videoStreamOptions.crf = crf;
606+
VideoEncoder(
607+
frames,
608+
validateInt64ToInt(frame_rate, "frame_rate"),
609+
file_name,
610+
videoStreamOptions)
611+
.encode();
612+
}
613+
614+
at::Tensor encode_video_to_tensor(
615+
const at::Tensor& frames,
616+
int64_t frame_rate,
617+
std::string_view format,
618+
std::optional<int64_t> crf = std::nullopt) {
619+
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
620+
VideoStreamOptions videoStreamOptions;
621+
videoStreamOptions.crf = crf;
622+
return VideoEncoder(
623+
frames,
624+
validateInt64ToInt(frame_rate, "frame_rate"),
625+
format,
626+
std::move(avioContextHolder),
627+
videoStreamOptions)
628+
.encodeToTensor();
629+
}
630+
612631
// For testing only. We need to implement this operation as a core library
613632
// function because what we're testing is round-tripping pts values as
614633
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -869,9 +888,10 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
869888

870889
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
871890
m.impl("encode_audio_to_file", &encode_audio_to_file);
872-
m.impl("encode_video_to_file", &encode_video_to_file);
873891
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
874892
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
893+
m.impl("encode_video_to_file", &encode_video_to_file);
894+
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
875895
m.impl("seek_to_pts", &seek_to_pts);
876896
m.impl("add_video_stream", &add_video_stream);
877897
m.impl("_add_video_stream", &_add_video_stream);

0 commit comments

Comments
 (0)