Skip to content

Commit 681b9cc

Browse files
committed
Refactor AVIOContextHolder
1 parent 0117a78 commit 681b9cc

File tree

4 files changed

+110
-136
lines changed

4 files changed

+110
-136
lines changed

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 14 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
7878
#endif
7979
}
8080

81-
AVIOBytesContext::AVIOBytesContext(
82-
const void* data,
83-
int64_t dataSize,
84-
int bufferSize)
85-
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
86-
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
87-
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
88-
81+
void AVIOContextHolder::createAVIOContext(
82+
AVIOReadFunction read,
83+
AVIOSeekFunction seek,
84+
void* heldData,
85+
int bufferSize) {
86+
TORCH_CHECK(
87+
bufferSize > 0,
88+
"Buffer size must be greater than 0; is " + std::to_string(bufferSize));
8989
auto buffer = static_cast<uint8_t*>(av_malloc(bufferSize));
9090
TORCH_CHECK(
9191
buffer != nullptr,
@@ -95,74 +95,25 @@ AVIOBytesContext::AVIOBytesContext(
9595
buffer,
9696
bufferSize,
9797
0,
98-
&dataContext_,
99-
&AVIOBytesContext::read,
100-
nullptr,
101-
&AVIOBytesContext::seek));
98+
heldData,
99+
read,
100+
nullptr, // write function; not supported yet
101+
seek));
102102

103103
if (!avioContext_) {
104104
av_freep(&buffer);
105105
TORCH_CHECK(false, "Failed to allocate AVIOContext");
106106
}
107107
}
108108

109-
AVIOBytesContext::~AVIOBytesContext() {
109+
AVIOContextHolder::~AVIOContextHolder() {
110110
if (avioContext_) {
111111
av_freep(&avioContext_->buffer);
112112
}
113113
}
114114

115-
AVIOContext* AVIOBytesContext::getAVIOContext() const {
115+
AVIOContext* AVIOContextHolder::getAVIOContext() {
116116
return avioContext_.get();
117117
}
118118

119-
// The signature of this function is defined by FFmpeg.
120-
int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) {
121-
auto dataContext = static_cast<DataContext*>(opaque);
122-
TORCH_CHECK(
123-
dataContext->current <= dataContext->size,
124-
"Tried to read outside of the buffer: current=",
125-
dataContext->current,
126-
", size=",
127-
dataContext->size);
128-
129-
buf_size = FFMIN(
130-
buf_size, static_cast<int>(dataContext->size - dataContext->current));
131-
TORCH_CHECK(
132-
buf_size >= 0,
133-
"Tried to read negative bytes: buf_size=",
134-
buf_size,
135-
", size=",
136-
dataContext->size,
137-
", current=",
138-
dataContext->current);
139-
140-
if (!buf_size) {
141-
return AVERROR_EOF;
142-
}
143-
memcpy(buf, dataContext->data + dataContext->current, buf_size);
144-
dataContext->current += buf_size;
145-
return buf_size;
146-
}
147-
148-
// The signature of this function is defined by FFmpeg.
149-
int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
150-
auto dataContext = static_cast<DataContext*>(opaque);
151-
int64_t ret = -1;
152-
153-
switch (whence) {
154-
case AVSEEK_SIZE:
155-
ret = dataContext->size;
156-
break;
157-
case SEEK_SET:
158-
dataContext->current = offset;
159-
ret = offset;
160-
break;
161-
default:
162-
break;
163-
}
164-
165-
return ret;
166-
}
167-
168119
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -145,43 +145,27 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext);
145145
// Returns true if sws_scale can handle unaligned data.
146146
bool canSwsScaleHandleUnalignedData();
147147

148+
using AVIOReadFunction = int (*)(void*, uint8_t*, int);
149+
using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
150+
148151
// TODO: explain purpose of context holder
149152
class AVIOContextHolder {
150153
public:
151-
virtual ~AVIOContextHolder(){};
152-
virtual AVIOContext* getAVIOContext() const = 0;
153-
};
154-
155-
// TODO: make comment below better
156-
// A struct that holds state for reading bytes from an IO context.
157-
// We give this to FFMPEG and it will pass it back to us when it needs to read
158-
// or seek in the memory buffer.
159-
//
160-
// A class that can be used as AVFormatContext's IO context. It reads from a
161-
// memory buffer that is passed in.
162-
class AVIOBytesContext : public AVIOContextHolder {
163-
public:
164-
AVIOBytesContext(const void* data, int64_t dataSize, int bufferSize);
165-
virtual ~AVIOBytesContext();
154+
virtual ~AVIOContextHolder();
155+
AVIOContext* getAVIOContext();
166156

167-
// Returns the AVIOContext that can be passed to FFMPEG.
168-
virtual AVIOContext* getAVIOContext() const override;
169-
170-
// The signature of this function is defined by FFMPEG.
171-
static int read(void* opaque, uint8_t* buf, int buf_size);
172-
173-
// The signature of this function is defined by FFMPEG.
174-
static int64_t seek(void* opaque, int64_t offset, int whence);
157+
protected:
158+
void createAVIOContext(
159+
AVIOReadFunction read,
160+
AVIOSeekFunction seek,
161+
void* heldData,
162+
int bufferSize = defaultBufferSize);
175163

176164
private:
177-
struct DataContext {
178-
const uint8_t* data;
179-
int64_t size;
180-
int64_t current;
181-
};
182-
183165
UniqueAVIOContext avioContext_;
184-
DataContext dataContext_;
166+
167+
// Defaults to 64 KB
168+
// 64 * 1024 bytes; used as the default bufferSize of createAVIOContext.
static const int defaultBufferSize = 64 * 1024;
185169
};
186170

187171
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/PyBindOps.cpp

Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ struct PyObjectDeleter {
2828

2929
class AVIOFileLikeContext : public AVIOContextHolder {
3030
public:
31-
AVIOFileLikeContext(py::object fileLike, int bufferSize)
32-
: fileLikeContext_{
33-
std::unique_ptr<py::object, PyObjectDeleter>(
34-
new py::object(fileLike)),
35-
bufferSize} {
31+
explicit AVIOFileLikeContext(py::object fileLike)
32+
: fileLikeContext_{std::unique_ptr<py::object, PyObjectDeleter>(
33+
new py::object(fileLike))} {
3634
{
3735
// TODO: Is it necessary to acquire the GIL here? Is it maybe even
3836
// harmful? At the moment, this is only called from within a pybind
@@ -45,40 +43,11 @@ class AVIOFileLikeContext : public AVIOContextHolder {
4543
py::hasattr(fileLike, "seek"),
4644
"File like object must implement a seek method.");
4745
}
48-
49-
auto buffer = static_cast<uint8_t*>(av_malloc(bufferSize));
50-
TORCH_CHECK(
51-
buffer != nullptr,
52-
"Failed to allocate buffer of size " + std::to_string(bufferSize));
53-
54-
avioContext_.reset(avio_alloc_context(
55-
buffer,
56-
bufferSize,
57-
0,
58-
&fileLikeContext_,
59-
&AVIOFileLikeContext::read,
60-
nullptr,
61-
&AVIOFileLikeContext::seek));
62-
63-
if (!avioContext_) {
64-
av_freep(&buffer);
65-
TORCH_CHECK(false, "Failed to allocate AVIOContext");
66-
}
67-
}
68-
69-
virtual ~AVIOFileLikeContext() {
70-
if (avioContext_) {
71-
av_freep(&avioContext_->buffer);
72-
}
73-
}
74-
75-
virtual AVIOContext* getAVIOContext() const override {
76-
return avioContext_.get();
46+
createAVIOContext(&read, &seek, &fileLikeContext_);
7747
}
7848

7949
static int read(void* opaque, uint8_t* buf, int buf_size) {
8050
auto fileLikeContext = static_cast<FileLikeContext*>(opaque);
81-
buf_size = FFMIN(buf_size, fileLikeContext->bufferSize);
8251

8352
int num_read = 0;
8453
while (num_read < buf_size) {
@@ -126,10 +95,8 @@ class AVIOFileLikeContext : public AVIOContextHolder {
12695
//
12796
// https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors
12897
std::unique_ptr<py::object, PyObjectDeleter> fileLike;
129-
int bufferSize;
13098
};
13199

132-
UniqueAVIOContext avioContext_;
133100
FileLikeContext fileLikeContext_;
134101
};
135102

@@ -150,9 +117,7 @@ int64_t create_from_file_like(
150117
realSeek = seekModeFromString(seek_mode.value());
151118
}
152119

153-
constexpr int bufferSize = 64 * 1024;
154-
auto contextHolder =
155-
std::make_unique<AVIOFileLikeContext>(file_like, bufferSize);
120+
auto contextHolder = std::make_unique<AVIOFileLikeContext>(file_like);
156121

157122
VideoDecoder* decoder = new VideoDecoder(std::move(contextHolder), realSeek);
158123
return reinterpret_cast<int64_t>(decoder);

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,82 @@ TORCH_LIBRARY(torchcodec_ns, m) {
6464
}
6565

6666
namespace {
67+
68+
// TODO: make comment below better
69+
// A struct that holds state for reading bytes from an IO context.
70+
// We give this to FFMPEG and it will pass it back to us when it needs to read
71+
// or seek in the memory buffer.
72+
//
73+
// A class that can be used as AVFormatContext's IO context. It reads from a
74+
// memory buffer that is passed in.
75+
class AVIOBytesContext : public AVIOContextHolder {
76+
public:
77+
explicit AVIOBytesContext(const void* data, int64_t dataSize)
78+
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
79+
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
80+
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
81+
createAVIOContext(&read, &seek, &dataContext_);
82+
}
83+
84+
// The signature of this function is defined by FFMPEG.
85+
static int read(void* opaque, uint8_t* buf, int buf_size) {
86+
auto dataContext = static_cast<DataContext*>(opaque);
87+
TORCH_CHECK(
88+
dataContext->current <= dataContext->size,
89+
"Tried to read outside of the buffer: current=",
90+
dataContext->current,
91+
", size=",
92+
dataContext->size);
93+
94+
buf_size = FFMIN(
95+
buf_size, static_cast<int>(dataContext->size - dataContext->current));
96+
TORCH_CHECK(
97+
buf_size >= 0,
98+
"Tried to read negative bytes: buf_size=",
99+
buf_size,
100+
", size=",
101+
dataContext->size,
102+
", current=",
103+
dataContext->current);
104+
105+
if (!buf_size) {
106+
return AVERROR_EOF;
107+
}
108+
memcpy(buf, dataContext->data + dataContext->current, buf_size);
109+
dataContext->current += buf_size;
110+
return buf_size;
111+
}
112+
113+
// The signature of this function is defined by FFMPEG.
114+
static int64_t seek(void* opaque, int64_t offset, int whence) {
115+
auto dataContext = static_cast<DataContext*>(opaque);
116+
int64_t ret = -1;
117+
118+
switch (whence) {
119+
case AVSEEK_SIZE:
120+
ret = dataContext->size;
121+
break;
122+
case SEEK_SET:
123+
dataContext->current = offset;
124+
ret = offset;
125+
break;
126+
default:
127+
break;
128+
}
129+
130+
return ret;
131+
}
132+
133+
private:
134+
struct DataContext {
135+
const uint8_t* data;
136+
int64_t size;
137+
int64_t current;
138+
};
139+
140+
DataContext dataContext_;
141+
};
142+
67143
at::Tensor wrapDecoderPointerToTensor(
68144
std::unique_ptr<VideoDecoder> uniqueDecoder) {
69145
VideoDecoder* decoder = uniqueDecoder.release();
@@ -135,9 +211,7 @@ at::Tensor create_from_tensor(
135211
realSeek = seekModeFromString(seek_mode.value());
136212
}
137213

138-
constexpr int bufferSize = 64 * 1024;
139-
auto contextHolder =
140-
std::make_unique<AVIOBytesContext>(data, length, bufferSize);
214+
auto contextHolder = std::make_unique<AVIOBytesContext>(data, length);
141215

142216
std::unique_ptr<VideoDecoder> uniqueDecoder =
143217
std::make_unique<VideoDecoder>(std::move(contextHolder), realSeek);

0 commit comments

Comments
 (0)