Skip to content

Commit 50938fe

Browse files
committed
Support encoding into tensor
1 parent 8467b92 commit 50938fe

File tree

12 files changed

+211
-101
lines changed

12 files changed

+211
-101
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize)
1313
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
1414
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
1515
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
16-
createAVIOContext(&read, &seek, &dataContext_);
16+
createAVIOContext(&read, nullptr, &seek, &dataContext_);
1717
}
1818

1919
// The signature of this function is defined by FFMPEG.
@@ -67,4 +67,26 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
6767
return ret;
6868
}
6969

70+
AVIOToTensorContext::AVIOToTensorContext()
71+
: dataContext_{torch::empty({OUTPUT_TENSOR_SIZE}, {torch::kUInt8}), 0} {
72+
createAVIOContext(nullptr, &write, nullptr, &dataContext_);
73+
}
74+
75+
// The signature of this function is defined by FFMPEG.
76+
int AVIOToTensorContext::write(void* opaque, uint8_t* buf, int buf_size) {
77+
auto dataContext = static_cast<DataContext*>(opaque);
78+
TORCH_CHECK(
79+
dataContext->current + buf_size <= OUTPUT_TENSOR_SIZE,
80+
"Can't encode more, output tensor needs to be re-allocated and this isn't supported yet.");
81+
uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
82+
std::memcpy(outputTensorData + dataContext->current, buf, buf_size);
83+
dataContext->current += static_cast<int64_t>(buf_size);
84+
return buf_size;
85+
}
86+
87+
torch::Tensor AVIOToTensorContext::getOutputTensor() {
88+
return dataContext_.outputTensor.narrow(
89+
/*dim=*/0, /*start=*/0, /*length=*/dataContext_.current);
90+
}
91+
7092
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#pragma once
88

9+
#include <torch/types.h>
910
#include "src/torchcodec/_core/AVIOContextHolder.h"
1011

1112
namespace facebook::torchcodec {
@@ -29,4 +30,24 @@ class AVIOBytesContext : public AVIOContextHolder {
2930
DataContext dataContext_;
3031
};
3132

33+
class AVIOToTensorContext : public AVIOContextHolder {
34+
public:
35+
explicit AVIOToTensorContext();
36+
torch::Tensor getOutputTensor();
37+
38+
private:
39+
// Should this class be tensor-aware? Or should we just store a uint8* buffer
40+
// instead of the tensor? If it's not tensor-aware it means we need to do the
41+
// (re)allocation outside of it. Same for the call to narrow().
42+
struct DataContext {
43+
torch::Tensor outputTensor;
44+
int64_t current;
45+
};
46+
47+
static const int OUTPUT_TENSOR_SIZE = 5'000'000; // TODO-ENCODING handle this
48+
static int write(void* opaque, uint8_t* buf, int buf_size);
49+
50+
DataContext dataContext_;
51+
};
52+
3253
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOContextHolder.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace facebook::torchcodec {
1111

1212
void AVIOContextHolder::createAVIOContext(
1313
AVIOReadFunction read,
14+
AVIOWriteFunction write,
1415
AVIOSeekFunction seek,
1516
void* heldData,
1617
int bufferSize) {
@@ -22,13 +23,17 @@ void AVIOContextHolder::createAVIOContext(
2223
buffer != nullptr,
2324
"Failed to allocate buffer of size " + std::to_string(bufferSize));
2425

26+
TORCH_CHECK(
27+
write != nullptr ^ (read != nullptr && seek != nullptr),
28+
"read and seek methods must be defined, or write method must be defined. "
29+
"But not both!")
2530
avioContext_.reset(avio_alloc_context(
2631
buffer,
2732
bufferSize,
28-
0,
33+
/*write_flag=*/write != nullptr,
2934
heldData,
3035
read,
31-
nullptr, // write function; not supported yet
36+
write,
3237
seek));
3338

3439
if (!avioContext_) {

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ namespace facebook::torchcodec {
1818
// UniqueAVIOContext, as the AVIOContext points to a buffer which must be
1919
// freed.
2020
// 2. It is a base class for AVIOContext specializations. When specializing a
21-
// AVIOContext, we need to provide four things:
22-
// 1. A read callback function.
23-
// 2. A seek callback function.
24-
// 3. A write callback function. (Not supported yet; it's for encoding.)
25-
// 4. A pointer to some context object that has the same lifetime as the
21+
// AVIOContext, we need to provide:
22+
// 1. - For decoding: A read callback function and a seek callback
23+
// function.
24+
// - For encoding: A write callback function.
25+
// 2. A pointer to some context object that has the same lifetime as the
2626
// AVIOContext itself. This context object holds the custom state that
2727
// tracks the custom behavior of reading, seeking and writing. It is
2828
// provided upon AVIOContext creation and to the read, seek and
@@ -46,11 +46,13 @@ class AVIOContextHolder {
4646

4747
// These signatures are defined by FFmpeg.
4848
using AVIOReadFunction = int (*)(void*, uint8_t*, int);
49+
using AVIOWriteFunction = int (*)(void*, uint8_t*, int);
4950
using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
5051

5152
// Deriving classes should call this function in their constructor.
5253
void createAVIOContext(
5354
AVIOReadFunction read,
55+
AVIOWriteFunction write,
5456
AVIOSeekFunction seek,
5557
void* heldData,
5658
int bufferSize = defaultBufferSize);

src/torchcodec/_core/AVIOFileLikeContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
2323
py::hasattr(fileLike, "seek"),
2424
"File like object must implement a seek method.");
2525
}
26-
createAVIOContext(&read, &seek, &fileLike_);
26+
createAVIOContext(&read, nullptr, &seek, &fileLike_);
2727
}
2828

2929
int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {

src/torchcodec/_core/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ find_package(pybind11 REQUIRED)
88
find_package(Torch REQUIRED)
99
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
1010

11-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
11+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
12+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
1213

1314
function(make_torchcodec_sublibrary
1415
library_name
@@ -60,11 +61,13 @@ function(make_torchcodec_libraries
6061
set(decoder_sources
6162
AVIOContextHolder.cpp
6263
FFMPEGCommon.cpp
63-
DeviceInterface.cpp
64+
DeviceInterface.cpp
6465
SingleStreamDecoder.cpp
6566
# TODO: lib name should probably not be "*_decoder*" now that it also
6667
# contains an encoder
6768
Encoder.cpp
69+
# TODO-Encoding remove from here. Should only be needed in custom_ops.cpp
70+
AVIOBytesContext.cpp
6871
)
6972

7073
if(ENABLE_CUDA)

src/torchcodec/_core/Encoder.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <sstream>
22

3+
#include "src/torchcodec/_core/AVIOBytesContext.h"
34
#include "src/torchcodec/_core/Encoder.h"
45
#include "torch/types.h"
56

@@ -40,9 +41,13 @@ AudioEncoder::~AudioEncoder() {}
4041
AudioEncoder::AudioEncoder(
4142
const torch::Tensor wf,
4243
int sampleRate,
43-
std::string_view fileName,
44+
std::optional<std::string_view> fileName,
45+
std::optional<std::string_view> formatName,
4446
std::optional<int64_t> bit_rate)
4547
: wf_(wf) {
48+
TORCH_CHECK(
49+
fileName.has_value() ^ formatName.has_value(),
50+
"Pass one of filename OR format, not both.");
4651
TORCH_CHECK(
4752
wf_.dtype() == torch::kFloat32,
4853
"waveform must have float32 dtype, got ",
@@ -52,27 +57,35 @@ AudioEncoder::AudioEncoder(
5257
TORCH_CHECK(
5358
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
5459

60+
avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
61+
5562
setFFmpegLogLevel();
5663
AVFormatContext* avFormatContext = nullptr;
57-
auto status = avformat_alloc_output_context2(
58-
&avFormatContext, nullptr, nullptr, fileName.data());
64+
int status = AVSUCCESS;
65+
if (fileName.has_value()) {
66+
status = avformat_alloc_output_context2(
67+
&avFormatContext, nullptr, nullptr, fileName->data());
68+
} else {
69+
status = avformat_alloc_output_context2(
70+
&avFormatContext, nullptr, formatName->data(), nullptr);
71+
}
5972
TORCH_CHECK(
6073
avFormatContext != nullptr,
6174
"Couldn't allocate AVFormatContext. ",
6275
"Check the desired extension? ",
6376
getFFMPEGErrorStringFromErrorCode(status));
6477
avFormatContext_.reset(avFormatContext);
6578

66-
// TODO-ENCODING: Should also support encoding into bytes (use
67-
// AVIOBytesContext)
68-
TORCH_CHECK(
69-
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
70-
"AVFMT_NOFILE is set. We only support writing to a file.");
71-
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
72-
TORCH_CHECK(
73-
status >= 0,
74-
"avio_open failed: ",
75-
getFFMPEGErrorStringFromErrorCode(status));
79+
if (fileName.has_value()) {
80+
status =
81+
avio_open(&avFormatContext_->pb, fileName->data(), AVIO_FLAG_WRITE);
82+
TORCH_CHECK(
83+
status >= 0,
84+
"avio_open failed: ",
85+
getFFMPEGErrorStringFromErrorCode(status));
86+
} else {
87+
avFormatContext->pb = avioContextHolder_->getAVIOContext();
88+
}
7689

7790
// We use the AVFormatContext's default codec for that
7891
// specific format/container.
@@ -168,7 +181,18 @@ AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
168181
return avCodec.sample_fmts[0];
169182
}
170183

184+
torch::Tensor AudioEncoder::encodeToTensor() {
185+
TORCH_CHECK(
186+
avioContextHolder_ != nullptr,
187+
"Cannot encode to tensor, avio context doesn't exist.");
188+
encode();
189+
return avioContextHolder_->getOutputTensor();
190+
}
191+
171192
void AudioEncoder::encode() {
193+
// TODO-ENCODING: Need to check, but consecutive calls to encode() are
194+
// probably invalid. We can address this once we (re)design the public and
195+
// private encoding APIs.
172196
UniqueAVFrame avFrame(av_frame_alloc());
173197
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
174198
// Default to 256 like in torchaudio

src/torchcodec/_core/Encoder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <torch/types.h>
3+
#include "src/torchcodec/_core/AVIOBytesContext.h"
34
#include "src/torchcodec/_core/FFMPEGCommon.h"
45

56
namespace facebook::torchcodec {
@@ -19,9 +20,11 @@ class AudioEncoder {
1920
// match this, and that's up to the user. If sample rates don't match,
2021
// encoding will still work but audio will be distorted.
2122
int sampleRate,
22-
std::string_view fileName,
23+
std::optional<std::string_view> fileName,
24+
std::optional<std::string_view> formatName,
2325
std::optional<int64_t> bit_rate = std::nullopt);
2426
void encode();
27+
torch::Tensor encodeToTensor();
2528

2629
private:
2730
void encodeInnerLoop(
@@ -36,5 +39,8 @@ class AudioEncoder {
3639
UniqueSwrContext swrContext_;
3740

3841
const torch::Tensor wf_;
42+
43+
// Stores the AVIOContext for the output tensor buffer.
44+
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
3945
};
4046
} // namespace facebook::torchcodec

src/torchcodec/_core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
_test_frame_pts_equality,
1919
add_audio_stream,
2020
add_video_stream,
21-
create_audio_encoder,
2221
create_from_bytes,
2322
create_from_file,
2423
create_from_file_like,
2524
create_from_tensor,
26-
encode_audio,
25+
encode_audio_to_file,
26+
encode_audio_to_tensor,
2727
get_ffmpeg_library_versions,
2828
get_frame_at_index,
2929
get_frame_at_pts,

0 commit comments

Comments
 (0)