
Commit baa9798

Reorganize public APIs in videoDecoder.h (#477)
1 parent 746401c commit baa9798

1 file changed: +83 −92 lines


src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 83 additions & 92 deletions
@@ -16,30 +16,7 @@
 
 namespace facebook::torchcodec {
 
-/*
-The VideoDecoder class can be used to decode video frames to Tensors.
-
-Example usage of this class:
-  std::string video_file_path = "/path/to/video.mp4";
-  VideoDecoder video_decoder = VideoDecoder::createFromFilePath(video_file_path);
-
-  // After creating the decoder, we can query the metadata:
-  auto metadata = video_decoder.getContainerMetadata();
-
-  // We can also add streams to the decoder:
-  // -1 sets the default stream.
-  video_decoder.addVideoStreamDecoder(-1);
-
-  // API for seeking and frame extraction:
-  // Let's extract the first frame at or after pts=5.0 seconds.
-  video_decoder.setCursorPtsInSeconds(5.0);
-  auto output = video_decoder->getNextFrameOutput();
-  torch::Tensor frame = output.frame;
-  double presentation_timestamp = output.ptsSeconds;
-  // Note that presentation_timestamp can be any timestamp at 5.0 or above
-  // because the frame time may not align exactly with the seek time.
-  CHECK_GE(presentation_timestamp, 5.0);
-*/
+// The VideoDecoder class can be used to decode video frames to Tensors.
 // Note that VideoDecoder is not thread-safe.
 // Do not call non-const APIs concurrently on the same object.
 class VideoDecoder {
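The removed block comment above held the only end-to-end usage example in the header. For orientation, here is a minimal sketch of the same flow against the reorganized API, assembled only from names visible in this diff (createFromFilePath, getContainerMetadata, addVideoStreamDecoder, setCursorPtsInSeconds, getNextFrameNoDemux, FrameOutput). The include path, build setup, and the exact addVideoStreamDecoder signature are assumptions, not part of this commit.

    // Minimal usage sketch; include path and addVideoStreamDecoder signature
    // are assumptions based on the removed comment, not part of this commit.
    #include "src/torchcodec/decoders/_core/VideoDecoder.h"

    using facebook::torchcodec::VideoDecoder;

    int main() {
      std::unique_ptr<VideoDecoder> decoder =
          VideoDecoder::createFromFilePath("/path/to/video.mp4");

      // Query the container metadata before decoding (VIDEO METADATA QUERY API).
      auto metadata = decoder->getContainerMetadata();
      (void)metadata;

      // -1 selects the default video stream, as in the removed comment.
      decoder->addVideoStreamDecoder(-1);

      // Seek to 5.0s, then decode the first frame at or after that position.
      decoder->setCursorPtsInSeconds(5.0);
      VideoDecoder::FrameOutput out = decoder->getNextFrameNoDemux();
      torch::Tensor frame = out.data; // 3D: CHW or HWC
      double pts = out.ptsSeconds;    // may be any value >= 5.0
      return (frame.numel() > 0 && pts >= 5.0) ? 0 : 1;
    }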
@@ -52,17 +29,16 @@ class VideoDecoder {
 
   enum class SeekMode { exact, approximate };
 
-  // Creates a VideoDecoder from the video at videoFilePath.
   explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
-
-  // Creates a VideoDecoder from a given buffer. Note that the buffer is not
-  // owned by the VideoDecoder.
   explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
 
+  // Creates a VideoDecoder from the video at videoFilePath.
   static std::unique_ptr<VideoDecoder> createFromFilePath(
       const std::string& videoFilePath,
       SeekMode seekMode = SeekMode::exact);
 
+  // Creates a VideoDecoder from a given buffer. Note that the buffer is not
+  // owned by the VideoDecoder.
   static std::unique_ptr<VideoDecoder> createFromBuffer(
       const void* buffer,
       size_t length,
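The hunk above moves the createFromBuffer documentation next to its declaration; the key caveat is that the decoder does not take ownership of the buffer, so the caller must keep the bytes alive for the decoder's lifetime. A hedged sketch of that contract follows; the file-reading helper and the choice of SeekMode::approximate are illustrative assumptions.

    // Sketch: constructing a decoder from an in-memory buffer. The helper and
    // the seek mode are illustrative only.
    #include <fstream>
    #include <iterator>
    #include <string>
    #include <vector>

    #include "src/torchcodec/decoders/_core/VideoDecoder.h"

    using facebook::torchcodec::VideoDecoder;

    std::vector<char> readAllBytes(const std::string& path) {
      std::ifstream file(path, std::ios::binary);
      return std::vector<char>(std::istreambuf_iterator<char>(file), {});
    }

    std::unique_ptr<VideoDecoder> makeDecoder(const std::vector<char>& bytes) {
      // The decoder does not own the buffer: the caller must keep `bytes`
      // alive for as long as the returned decoder is used.
      return VideoDecoder::createFromBuffer(
          bytes.data(), bytes.size(), VideoDecoder::SeekMode::approximate);
    }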
@@ -71,8 +47,10 @@ class VideoDecoder {
   // --------------------------------------------------------------------------
   // VIDEO METADATA QUERY API
   // --------------------------------------------------------------------------
+
   // Updates the metadata of the video to accurate values obtained by scanning
-  // the contents of the video file.
+  // the contents of the video file. Also updates each StreamInfo's index, i.e.
+  // the allFrames and keyFrames vectors.
   void scanFileAndUpdateMetadataAndIndex();
 
   struct StreamMetadata {
@@ -88,7 +66,6 @@ class VideoDecoder {
     std::optional<int64_t> numKeyFrames;
     std::optional<double> averageFps;
     std::optional<double> bitRate;
-    std::optional<std::vector<int64_t>> keyFrames;
 
     // More accurate duration, obtained by scanning the file.
     // These presentation timestamps are in time base.
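The StreamMetadata fields above are std::optional because some values are only available after scanFileAndUpdateMetadataAndIndex() has scanned the file, or may not be reported by the container at all. A small hedged sketch of defensive access; how a StreamMetadata instance is obtained from the decoder is not shown in this diff, so the helper simply takes one as a parameter.

    // Sketch: defensive access to the optional metadata fields shown above.
    // Assumes only that StreamMetadata is a nested type of VideoDecoder.
    #include <optional>

    double averageFpsOr(
        const facebook::torchcodec::VideoDecoder::StreamMetadata& meta,
        double fallback) {
      // averageFps (like numKeyFrames and bitRate) may be empty until the
      // file has been scanned, or if the container does not report it.
      return meta.averageFps.value_or(fallback);
    }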
@@ -126,6 +103,7 @@ class VideoDecoder {
   // --------------------------------------------------------------------------
   // ADDING STREAMS API
   // --------------------------------------------------------------------------
+
   enum ColorConversionLibrary {
     // TODO: Add an AUTO option later.
     // Use the libavfilter library for color conversion.
@@ -164,96 +142,71 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
 
-  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
-
-  // ---- SINGLE FRAME SEEK AND DECODING API ----
-  // Places the cursor at the first frame on or after the position in seconds.
-  // Calling getNextFrameNoDemuxInternal() will return the first frame at
-  // or after this position.
-  void setCursorPtsInSeconds(double seconds);
-
-  // This structure ensures we always keep the streamIndex and AVFrame together
-  // Note that AVFrame itself doesn't retain the streamIndex.
-  struct AVFrameStream {
-    // The actual decoded output as a unique pointer to an AVFrame.
-    UniqueAVFrame avFrame;
-    // The stream index of the decoded frame.
-    int streamIndex;
-
-    explicit AVFrameStream(UniqueAVFrame&& a, int s)
-        : avFrame(std::move(a)), streamIndex(s) {}
-  };
+  // --------------------------------------------------------------------------
+  // DECODING AND SEEKING APIs
+  // --------------------------------------------------------------------------
 
+  // All public decoding entry points return either a FrameOutput or a
+  // FrameBatchOutput.
+  // They are the equivalent of the user-facing Frame and FrameBatch classes in
+  // Python. They contain RGB decoded frames along with some associated data
+  // like PTS and duration.
   struct FrameOutput {
-    // The actual decoded output as a Tensor.
-    torch::Tensor data;
-    // The stream index of the decoded frame. Used to distinguish
-    // between streams that are of the same type.
+    torch::Tensor data; // 3D: of shape CHW or HWC.
     int streamIndex;
-    // The presentation timestamp of the decoded frame in seconds.
     double ptsSeconds;
-    // The duration of the decoded frame in seconds.
     double durationSeconds;
   };
 
   struct FrameBatchOutput {
-    torch::Tensor data;
-    torch::Tensor ptsSeconds;
-    torch::Tensor durationSeconds;
+    torch::Tensor data; // 4D: of shape NCHW or NHWC.
+    torch::Tensor ptsSeconds; // 1D of shape (N,)
+    torch::Tensor durationSeconds; // 1D of shape (N,)
 
     explicit FrameBatchOutput(
         int64_t numFrames,
         const VideoStreamOptions& videoStreamOptions,
         const StreamMetadata& streamMetadata);
   };
 
-  class EndOfFileException : public std::runtime_error {
-   public:
-    explicit EndOfFileException(const std::string& msg)
-        : std::runtime_error(msg) {}
-  };
+  // Places the cursor at the first frame on or after the position in seconds.
+  // Calling getNextFrameNoDemux() will return the first frame at
+  // or after this position.
+  void setCursorPtsInSeconds(double seconds);
 
   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
   FrameOutput getNextFrameNoDemux();
-  // Decodes the first frame in any added stream that is visible at a given
-  // timestamp. Frames in the video have a presentation timestamp and a
-  // duration. For example, if a frame has presentation timestamp of 5.0s and a
-  // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
-  // i.e. it will be returned when this function is called with seconds=5.0 or
-  // seconds=5.999, etc.
-  FrameOutput getFramePlayedAtNoDemux(double seconds);
 
   FrameOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
-  // This is morally private but needs to be exposed for C++ tests. Once
-  // getFrameAtIndex supports the preAllocatedOutputTensor parameter, we can
-  // move it back to private.
-  FrameOutput getFrameAtIndexInternal(
-      int streamIndex,
-      int64_t frameIndex,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   FrameBatchOutput getFramesAtIndices(
       int streamIndex,
       const std::vector<int64_t>& frameIndices);
 
+  // Returns frames within a given range. The range is defined by [start, stop).
+  // The values retrieved from the range are: [start, start+step,
+  // start+(2*step), start+(3*step), ..., stop). The default for step is 1.
+  FrameBatchOutput
+  getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
+
+  // Decodes the first frame in any added stream that is visible at a given
+  // timestamp. Frames in the video have a presentation timestamp and a
+  // duration. For example, if a frame has presentation timestamp of 5.0s and a
+  // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
+  // i.e. it will be returned when this function is called with seconds=5.0 or
+  // seconds=5.999, etc.
+  FrameOutput getFramePlayedAtNoDemux(double seconds);
+
   FrameBatchOutput getFramesPlayedAt(
       int streamIndex,
       const std::vector<double>& timestamps);
 
-  // Returns frames within a given range for a given stream as a single stacked
-  // Tensor. The range is defined by [start, stop). The values retrieved from
-  // the range are:
-  // [start, start+step, start+(2*step), start+(3*step), ..., stop)
-  // The default for step is 1.
-  FrameBatchOutput
-  getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
-
-  // Returns frames within a given pts range for a given stream as a single
-  // stacked tensor. The range is defined by [startSeconds, stopSeconds) with
-  // respect to the pts values for frames. The returned frames are in pts order.
+  // Returns frames within a given pts range. The range is defined by
+  // [startSeconds, stopSeconds) with respect to the pts values for frames. The
+  // returned frames are in pts order.
   //
   // Note that while stopSeconds is excluded in the half open range, this really
   // only makes a difference when stopSeconds is exactly the pts value for a
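To make the regrouped DECODING AND SEEKING APIs above concrete, here is a hedged sketch of the batch and timestamp entry points. It uses only declarations visible in this hunk; the stream index and the index/timestamp values are arbitrary example inputs.

    // Sketch of the batch APIs; example values only.
    #include <vector>

    #include "src/torchcodec/decoders/_core/VideoDecoder.h"

    using facebook::torchcodec::VideoDecoder;

    void batchExamples(VideoDecoder& decoder) {
      // Frames 0, 10, 20, ..., 90: the range is [start, stop) with step = 10.
      VideoDecoder::FrameBatchOutput range = decoder.getFramesInRange(
          /*streamIndex=*/0, /*start=*/0, /*stop=*/100, /*step=*/10);
      // range.data is 4D (NCHW or NHWC); range.ptsSeconds and
      // range.durationSeconds are 1D tensors of shape (N,).

      // Arbitrary frame indices, returned as a single stacked Tensor.
      VideoDecoder::FrameBatchOutput picked =
          decoder.getFramesAtIndices(/*streamIndex=*/0, {3, 5, 8});

      // The frame visible at t=5.0s: a frame with pts 5.0s and duration 1.0s
      // is returned for any query in [5.0, 6.0).
      VideoDecoder::FrameOutput at5 = decoder.getFramePlayedAtNoDemux(5.0);

      (void)range;
      (void)picked;
      (void)at5;
    }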
@@ -273,11 +226,47 @@ class VideoDecoder {
       double startSeconds,
       double stopSeconds);
 
+  class EndOfFileException : public std::runtime_error {
+   public:
+    explicit EndOfFileException(const std::string& msg)
+        : std::runtime_error(msg) {}
+  };
+
   // --------------------------------------------------------------------------
-  // DECODER PERFORMANCE STATISTICS API
+  // MORALLY PRIVATE APIS
   // --------------------------------------------------------------------------
+  // These are APIs that should be private, but that are effectively exposed for
+  // practical reasons, typically for testing purposes.
+
+  // This struct is needed because AVFrame doesn't retain the streamIndex. Only
+  // the AVPacket knows its stream. This is what the low-level private decoding
+  // entry points return. The AVFrameStream is then converted to a FrameOutput
+  // with convertAVFrameToFrameOutput. It should be private, but is currently
+  // used by DeviceInterface.
+  struct AVFrameStream {
+    // The actual decoded output as a unique pointer to an AVFrame.
+    // Usually, this is a YUV frame. It'll be converted to RGB in
+    // convertAVFrameToFrameOutput.
+    UniqueAVFrame avFrame;
+    // The stream index of the decoded frame.
+    int streamIndex;
 
-  // Only exposed for performance testing.
+    explicit AVFrameStream(UniqueAVFrame&& a, int s)
+        : avFrame(std::move(a)), streamIndex(s) {}
+  };
+
+  // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
+  // can move it back to private.
+  FrameOutput getFrameAtIndexInternal(
+      int streamIndex,
+      int64_t frameIndex,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  // Exposed for _test_frame_pts_equality, which is used to test non-regression
+  // of pts resolution (64 to 32 bit floats)
+  double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
+
+  // Exposed for performance testing.
   struct DecodeStats {
     int64_t numSeeksAttempted = 0;
     int64_t numSeeksDone = 0;
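The getFrameAtIndexInternal declaration above is the "morally private" hook used by C++ tests; per its comment, it will return to private once getFrameAtIndex accepts the pre-allocated tensor. A hedged sketch of how a test might call it follows; the shape and dtype requirements of the pre-allocated tensor are not specified in this diff and are assumed to match what the decoder produces for the stream.

    // Sketch: the test-only entry point with a caller-provided output tensor.
    // The required tensor shape/dtype is an assumption (e.g. HWC or CHW per
    // the FrameOutput documentation above).
    #include <optional>

    #include "src/torchcodec/decoders/_core/VideoDecoder.h"

    using facebook::torchcodec::VideoDecoder;

    VideoDecoder::FrameOutput decodeIntoPreallocated(
        VideoDecoder& decoder,
        int streamIndex,
        int64_t frameIndex,
        torch::Tensor preAllocated) {
      // Passing a tensor lets the caller reuse storage across decodes instead
      // of allocating a fresh output for every frame.
      return decoder.getFrameAtIndexInternal(
          streamIndex, frameIndex, preAllocated);
    }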
@@ -291,9 +280,9 @@ class VideoDecoder {
   DecodeStats getDecodeStats() const;
   void resetDecodeStats();
 
-  double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
-
  private:
+  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+
   struct FrameInfo {
     int64_t pts = 0;
     // The value of this default is important: the last frame's nextPts will be
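getDecodeStats() and resetDecodeStats() above are the performance hooks that this commit folds under the morally-private grouping. A small sketch of the reset/measure/read pattern; only the two DecodeStats counters visible in this diff are printed, and the struct's remaining fields are not shown here.

    // Sketch of the reset/measure/read pattern for the performance counters.
    #include <iostream>

    #include "src/torchcodec/decoders/_core/VideoDecoder.h"

    using facebook::torchcodec::VideoDecoder;

    void reportSeeks(VideoDecoder& decoder) {
      decoder.resetDecodeStats();
      // ... run the decoding workload being measured ...
      VideoDecoder::DecodeStats stats = decoder.getDecodeStats();
      std::cout << "seeks attempted: " << stats.numSeeksAttempted
                << ", seeks done: " << stats.numSeeksDone << std::endl;
    }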
@@ -404,8 +393,10 @@ class VideoDecoder {
       const enum AVColorSpace colorspace);
 
   void maybeSeekToBeforeDesiredPts();
+
   AVFrameStream decodeAVFrame(
       std::function<bool(int, AVFrame*)> filterFunction);
+
   // Once we create a decoder can update the metadata with the codec context.
   // For example, for video streams, we can add the height and width of the
   // decoded stream.
