Skip to content

Commit cfea9ab

Browse files
committed
Support encoding into tensor
1 parent 8467b92 commit cfea9ab

File tree

12 files changed

+236
-93
lines changed

12 files changed

+236
-93
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize)
1313
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
1414
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
1515
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
16-
createAVIOContext(&read, &seek, &dataContext_);
16+
createAVIOContext(&read, nullptr, &seek, &dataContext_);
1717
}
1818

1919
// The signature of this function is defined by FFMPEG.
@@ -67,4 +67,46 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
6767
return ret;
6868
}
6969

70+
// Allocates an uninitialized uint8 tensor of OUTPUT_TENSOR_SIZE elements to
// receive the encoded bytes, starts the write position at 0, and registers the
// write and seek callbacks (no read callback) with the base class.
AVIOToTensorContext::AVIOToTensorContext()
    : dataContext_{
          /*outputTensor=*/torch::empty({OUTPUT_TENSOR_SIZE}, {torch::kUInt8}),
          /*current=*/0} {
  createAVIOContext(
      /*read=*/nullptr, /*write=*/&write, /*seek=*/&seek, &dataContext_);
}
74+
75+
// The signature of this function is defined by FFMPEG.
76+
int AVIOToTensorContext::write(void* opaque, uint8_t* buf, int buf_size) {
77+
auto dataContext = static_cast<DataContext*>(opaque);
78+
TORCH_CHECK(
79+
dataContext->current + buf_size <= OUTPUT_TENSOR_SIZE,
80+
"Can't encode more, output tensor needs to be re-allocated and this isn't supported yet.");
81+
uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
82+
std::memcpy(outputTensorData + dataContext->current, buf, buf_size);
83+
dataContext->current += static_cast<int64_t>(buf_size);
84+
return buf_size;
85+
}
86+
87+
// The signature of this function is defined by FFMPEG.
88+
int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
89+
auto dataContext = static_cast<DataContext*>(opaque);
90+
int64_t ret = -1;
91+
92+
switch (whence) {
93+
case AVSEEK_SIZE:
94+
ret = dataContext->outputTensor.numel();
95+
break;
96+
case SEEK_SET:
97+
dataContext->current = offset;
98+
ret = offset;
99+
break;
100+
default:
101+
break;
102+
}
103+
104+
return ret;
105+
}
106+
107+
// Returns a view over the first `current` elements of the output tensor, i.e.
// everything from the start of the tensor up to the current write position.
//
// NOTE(review): `current` is the current position, not a high-water mark. If
// the position is moved backwards through seek() before this is called, the
// returned view would be shorter than the data written - confirm whether that
// can happen with the formats we support.
torch::Tensor AVIOToTensorContext::getOutputTensor() {
  const int64_t numBytes = dataContext_.current;
  return dataContext_.outputTensor.narrow(
      /*dim=*/0, /*start=*/0, /*length=*/numBytes);
}
111+
70112
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#pragma once
88

9+
#include <torch/types.h>
910
#include "src/torchcodec/_core/AVIOContextHolder.h"
1011

1112
namespace facebook::torchcodec {
@@ -29,4 +30,25 @@ class AVIOBytesContext : public AVIOContextHolder {
2930
DataContext dataContext_;
3031
};
3132

33+
class AVIOToTensorContext : public AVIOContextHolder {
34+
public:
35+
explicit AVIOToTensorContext();
36+
torch::Tensor getOutputTensor();
37+
38+
private:
39+
// Should this class be tensor-aware? Or should we just store a uint8* buffer
40+
// instead of the tensor? If it's not tensor-aware it means we need to do the
41+
// (re)allocation outside of it. Same for the call to narrow().
42+
struct DataContext {
43+
torch::Tensor outputTensor;
44+
int64_t current;
45+
};
46+
47+
static const int OUTPUT_TENSOR_SIZE = 5'000'000; // TODO-ENCODING handle this
48+
static int write(void* opaque, uint8_t* buf, int buf_size);
49+
static int64_t seek(void* opaque, int64_t offset, int whence);
50+
51+
DataContext dataContext_;
52+
};
53+
3254
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOContextHolder.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace facebook::torchcodec {
1111

1212
void AVIOContextHolder::createAVIOContext(
1313
AVIOReadFunction read,
14+
AVIOWriteFunction write,
1415
AVIOSeekFunction seek,
1516
void* heldData,
1617
int bufferSize) {
@@ -22,13 +23,17 @@ void AVIOContextHolder::createAVIOContext(
2223
buffer != nullptr,
2324
"Failed to allocate buffer of size " + std::to_string(bufferSize));
2425

26+
TORCH_CHECK(
27+
write != nullptr ^ (read != nullptr && seek != nullptr),
28+
"read and seek methods must be defined, or write method must be defined. "
29+
"But not both!")
2530
avioContext_.reset(avio_alloc_context(
2631
buffer,
2732
bufferSize,
28-
0,
33+
/*write_flag=*/write != nullptr,
2934
heldData,
3035
read,
31-
nullptr, // write function; not supported yet
36+
write,
3237
seek));
3338

3439
if (!avioContext_) {

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ namespace facebook::torchcodec {
1919
// freed.
2020
// 2. It is a base class for AVIOContext specializations. When specializing a
2121
// AVIOContext, we need to provide four things:
22-
// 1. A read callback function.
23-
// 2. A seek callback function.
24-
// 3. A write callback function. (Not supported yet; it's for encoding.)
22+
// 1. A read callback function, for decoding.
23+
// 2. A seek callback function, for decoding and encoding.
24+
// 3. A write callback function, for encoding.
2525
// 4. A pointer to some context object that has the same lifetime as the
2626
// AVIOContext itself. This context object holds the custom state that
2727
// tracks the custom behavior of reading, seeking and writing. It is
@@ -46,11 +46,13 @@ class AVIOContextHolder {
4646

4747
// These signatures are defined by FFmpeg.
4848
using AVIOReadFunction = int (*)(void*, uint8_t*, int);
49+
using AVIOWriteFunction = int (*)(void*, uint8_t*, int);
4950
using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
5051

5152
// Deriving classes should call this function in their constructor.
5253
void createAVIOContext(
5354
AVIOReadFunction read,
55+
AVIOWriteFunction write,
5456
AVIOSeekFunction seek,
5557
void* heldData,
5658
int bufferSize = defaultBufferSize);

src/torchcodec/_core/AVIOFileLikeContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
2323
py::hasattr(fileLike, "seek"),
2424
"File like object must implement a seek method.");
2525
}
26-
createAVIOContext(&read, &seek, &fileLike_);
26+
createAVIOContext(&read, nullptr, &seek, &fileLike_);
2727
}
2828

2929
int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {

src/torchcodec/_core/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ find_package(pybind11 REQUIRED)
88
find_package(Torch REQUIRED)
99
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
1010

11-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
11+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
12+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
1213

1314
function(make_torchcodec_sublibrary
1415
library_name
@@ -60,11 +61,13 @@ function(make_torchcodec_libraries
6061
set(decoder_sources
6162
AVIOContextHolder.cpp
6263
FFMPEGCommon.cpp
63-
DeviceInterface.cpp
64+
DeviceInterface.cpp
6465
SingleStreamDecoder.cpp
6566
# TODO: lib name should probably not be "*_decoder*" now that it also
6667
# contains an encoder
6768
Encoder.cpp
69+
# TODO-Encoding remove from here. Should only be needed in custom_ops.cpp
70+
AVIOBytesContext.cpp
6871
)
6972

7073
if(ENABLE_CUDA)

src/torchcodec/_core/Encoder.cpp

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <sstream>
22

3+
#include "src/torchcodec/_core/AVIOBytesContext.h"
34
#include "src/torchcodec/_core/Encoder.h"
45
#include "torch/types.h"
56

@@ -40,9 +41,13 @@ AudioEncoder::~AudioEncoder() {}
4041
AudioEncoder::AudioEncoder(
4142
const torch::Tensor wf,
4243
int sampleRate,
43-
std::string_view fileName,
44+
std::optional<std::string_view> fileName,
45+
std::optional<std::string_view> formatName,
4446
std::optional<int64_t> bit_rate)
4547
: wf_(wf) {
48+
TORCH_CHECK(
49+
fileName.has_value() ^ formatName.has_value(),
50+
"Pass one of filename OR format, not both.");
4651
TORCH_CHECK(
4752
wf_.dtype() == torch::kFloat32,
4853
"waveform must have float32 dtype, got ",
@@ -54,25 +59,32 @@ AudioEncoder::AudioEncoder(
5459

5560
setFFmpegLogLevel();
5661
AVFormatContext* avFormatContext = nullptr;
57-
auto status = avformat_alloc_output_context2(
58-
&avFormatContext, nullptr, nullptr, fileName.data());
62+
int status = AVSUCCESS;
63+
if (fileName.has_value()) {
64+
status = avformat_alloc_output_context2(
65+
&avFormatContext, nullptr, nullptr, fileName->data());
66+
} else {
67+
status = avformat_alloc_output_context2(
68+
&avFormatContext, nullptr, formatName->data(), nullptr);
69+
}
5970
TORCH_CHECK(
6071
avFormatContext != nullptr,
6172
"Couldn't allocate AVFormatContext. ",
6273
"Check the desired extension? ",
6374
getFFMPEGErrorStringFromErrorCode(status));
6475
avFormatContext_.reset(avFormatContext);
6576

66-
// TODO-ENCODING: Should also support encoding into bytes (use
67-
// AVIOBytesContext)
68-
TORCH_CHECK(
69-
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
70-
"AVFMT_NOFILE is set. We only support writing to a file.");
71-
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
72-
TORCH_CHECK(
73-
status >= 0,
74-
"avio_open failed: ",
75-
getFFMPEGErrorStringFromErrorCode(status));
77+
if (fileName.has_value()) {
78+
status =
79+
avio_open(&avFormatContext_->pb, fileName->data(), AVIO_FLAG_WRITE);
80+
TORCH_CHECK(
81+
status >= 0,
82+
"avio_open failed: ",
83+
getFFMPEGErrorStringFromErrorCode(status));
84+
} else {
85+
avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
86+
avFormatContext->pb = avioContextHolder_->getAVIOContext();
87+
}
7688

7789
// We use the AVFormatContext's default codec for that
7890
// specific format/container.
@@ -168,7 +180,18 @@ AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
168180
return avCodec.sample_fmts[0];
169181
}
170182

183+
// Runs encode() and returns the tensor held by the AVIO context holder.
// Fails (TORCH_CHECK) when no AVIO context holder exists.
torch::Tensor AudioEncoder::encodeToTensor() {
  const bool hasTensorContext = (avioContextHolder_ != nullptr);
  TORCH_CHECK(
      hasTensorContext,
      "Cannot encode to tensor, avio context doesn't exist.");

  encode();

  auto encodedBytes = avioContextHolder_->getOutputTensor();
  return encodedBytes;
}
190+
171191
void AudioEncoder::encode() {
192+
// TODO-ENCODING: Need to check, but consecutive calls to encode() are
193+
// probably invalid. We can address this once we (re)design the public and
194+
// private encoding APIs.
172195
UniqueAVFrame avFrame(av_frame_alloc());
173196
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
174197
// Default to 256 like in torchaudio

src/torchcodec/_core/Encoder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <torch/types.h>
3+
#include "src/torchcodec/_core/AVIOBytesContext.h"
34
#include "src/torchcodec/_core/FFMPEGCommon.h"
45

56
namespace facebook::torchcodec {
@@ -19,9 +20,11 @@ class AudioEncoder {
1920
// match this, and that's up to the user. If sample rates don't match,
2021
// encoding will still work but audio will be distorted.
2122
int sampleRate,
22-
std::string_view fileName,
23+
std::optional<std::string_view> fileName,
24+
std::optional<std::string_view> formatName,
2325
std::optional<int64_t> bit_rate = std::nullopt);
2426
void encode();
27+
torch::Tensor encodeToTensor();
2528

2629
private:
2730
void encodeInnerLoop(
@@ -36,5 +39,8 @@ class AudioEncoder {
3639
UniqueSwrContext swrContext_;
3740

3841
const torch::Tensor wf_;
42+
43+
// Stores the AVIOContext for the output tensor buffer.
44+
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
3945
};
4046
} // namespace facebook::torchcodec

src/torchcodec/_core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
_test_frame_pts_equality,
1919
add_audio_stream,
2020
add_video_stream,
21-
create_audio_encoder,
2221
create_from_bytes,
2322
create_from_file,
2423
create_from_file_like,
2524
create_from_tensor,
26-
encode_audio,
25+
encode_audio_to_file,
26+
encode_audio_to_tensor,
2727
get_ffmpeg_library_versions,
2828
get_frame_at_index,
2929
get_frame_at_pts,

0 commit comments

Comments
 (0)