Skip to content

Commit 42c6373

Browse files
committed
Allow output tensor re-allocation
1 parent d85baa2 commit 42c6373

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,32 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
6868
}
6969

7070
AVIOToTensorContext::AVIOToTensorContext()
71-
: dataContext_{torch::empty({OUTPUT_TENSOR_SIZE}, {torch::kUInt8}), 0} {
71+
: dataContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
7272
createAVIOContext(nullptr, &write, &seek, &dataContext_);
7373
}
7474

7575
// The signature of this function is defined by FFMPEG.
7676
int AVIOToTensorContext::write(void* opaque, uint8_t* buf, int buf_size) {
7777
auto dataContext = static_cast<DataContext*>(opaque);
78+
79+
if (dataContext->current + buf_size > dataContext->outputTensor.numel()) {
80+
TORCH_CHECK(
81+
dataContext->outputTensor.numel() * 2 <= MAX_TENSOR_SIZE,
82+
"We tried to allocate an output encoded tensor larger than ",
83+
MAX_TENSOR_SIZE,
84+
" bytes. If you think this should be supported, please report.");
85+
86+
// We double the size of the output tensor. Calling cat() may not be the
87+
// most efficient, but it's simple.
88+
dataContext->outputTensor =
89+
torch::cat({dataContext->outputTensor, dataContext->outputTensor});
90+
}
91+
7892
TORCH_CHECK(
79-
dataContext->current + buf_size <= OUTPUT_TENSOR_SIZE,
80-
"Can't encode more, output tensor needs to be re-allocated and this isn't supported yet.");
93+
dataContext->current + buf_size <= dataContext->outputTensor.numel(),
94+
"Re-allocation of the output tensor didn't work. ",
95+
"This should not happen, please report on TorchCodec bug tracker");
96+
8197
uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
8298
std::memcpy(outputTensorData + dataContext->current, buf, buf_size);
8399
dataContext->current += static_cast<int64_t>(buf_size);

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class AVIOToTensorContext : public AVIOContextHolder {
4242
int64_t current;
4343
};
4444

45-
static const int OUTPUT_TENSOR_SIZE = 5'000'000; // TODO-ENCODING handle this
45+
static const int INITIAL_TENSOR_SIZE = 10'000'000; // 10MB
46+
static const int MAX_TENSOR_SIZE = 320'000'000; // 320 MB
4647
static int write(void* opaque, uint8_t* buf, int buf_size);
4748
// We need to expose seek() for some formats like mp3.
4849
static int64_t seek(void* opaque, int64_t offset, int whence);

test/test_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,23 @@ def test_tensor_against_file(self, asset, bit_rate, output_format, tmp_path):
12241224
self.decode(encoded_file), self.decode(encoded_tensor)
12251225
)
12261226

1227+
def test_encode_to_tensor_long_output(self):
1228+
# Check that we support re-allocating the output tensor when the encoded
1229+
# data is large.
1230+
samples = torch.rand(1, int(1e7))
1231+
encoded_tensor = encode_audio_to_tensor(
1232+
wf=samples,
1233+
sample_rate=16_000,
1234+
format="flac",
1235+
bit_rate=44_000,
1236+
)
1237+
# Note: this should be in sync with its C++ counterpart for the test to
1238+
# be meaningful.
1239+
INITIAL_TENSOR_SIZE = 10_000_000
1240+
assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
1241+
1242+
torch.testing.assert_close(self.decode(encoded_tensor), samples)
1243+
12271244

12281245
if __name__ == "__main__":
12291246
pytest.main()

0 commit comments

Comments
 (0)