Skip to content

Commit e5713c0

Browse files
committed
Reorganize public stuff
1 parent 288bb83 commit e5713c0

File tree

1 file changed

+80
-91
lines changed

1 file changed

+80
-91
lines changed

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 80 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,7 @@
1616

1717
namespace facebook::torchcodec {
1818

19-
/*
20-
The VideoDecoder class can be used to decode video frames to Tensors.
21-
22-
Example usage of this class:
23-
std::string video_file_path = "/path/to/video.mp4";
24-
VideoDecoder video_decoder = VideoDecoder::createFromFilePath(video_file_path);
25-
26-
// After creating the decoder, we can query the metadata:
27-
auto metadata = video_decoder.getContainerMetadata();
28-
29-
// We can also add streams to the decoder:
30-
// -1 sets the default stream.
31-
video_decoder.addVideoStreamDecoder(-1);
32-
33-
// API for seeking and frame extraction:
34-
// Let's extract the first frame at or after pts=5.0 seconds.
35-
video_decoder.setCursorPtsInSeconds(5.0);
36-
auto output = video_decoder->getNextFrameOutput();
37-
torch::Tensor frame = output.frame;
38-
double presentation_timestamp = output.ptsSeconds;
39-
// Note that presentation_timestamp can be any timestamp at 5.0 or above
40-
// because the frame time may not align exactly with the seek time.
41-
CHECK_GE(presentation_timestamp, 5.0);
42-
*/
19+
// The VideoDecoder class can be used to decode video frames to Tensors.
4320
// Note that VideoDecoder is not thread-safe.
4421
// Do not call non-const APIs concurrently on the same object.
4522
class VideoDecoder {
@@ -53,16 +30,12 @@ class VideoDecoder {
5330
enum class SeekMode { exact, approximate };
5431

5532
// Creates a VideoDecoder from the video at videoFilePath.
56-
explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
57-
58-
// Creates a VideoDecoder from a given buffer. Note that the buffer is not
59-
// owned by the VideoDecoder.
60-
explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
61-
6233
static std::unique_ptr<VideoDecoder> createFromFilePath(
6334
const std::string& videoFilePath,
6435
SeekMode seekMode = SeekMode::exact);
6536

37+
// Creates a VideoDecoder from a given buffer. Note that the buffer is not
38+
// owned by the VideoDecoder.
6639
static std::unique_ptr<VideoDecoder> createFromBuffer(
6740
const void* buffer,
6841
size_t length,
@@ -71,8 +44,10 @@ class VideoDecoder {
7144
// --------------------------------------------------------------------------
7245
// VIDEO METADATA QUERY API
7346
// --------------------------------------------------------------------------
47+
7448
// Updates the metadata of the video to accurate values obtained by scanning
75-
// the contents of the video file.
49+
// the contents of the video file. Also updates each StreamInfo's index, i.e.
50+
// the allFrames and keyFrames vectors.
7651
void scanFileAndUpdateMetadataAndIndex();
7752

7853
struct StreamMetadata {
@@ -88,7 +63,6 @@ class VideoDecoder {
8863
std::optional<int64_t> numKeyFrames;
8964
std::optional<double> averageFps;
9065
std::optional<double> bitRate;
91-
std::optional<std::vector<int64_t>> keyFrames;
9266

9367
// More accurate duration, obtained by scanning the file.
9468
// These presentation timestamps are in time base.
@@ -126,6 +100,7 @@ class VideoDecoder {
126100
// --------------------------------------------------------------------------
127101
// ADDING STREAMS API
128102
// --------------------------------------------------------------------------
103+
129104
enum ColorConversionLibrary {
130105
// TODO: Add an AUTO option later.
131106
// Use the libavfilter library for color conversion.
@@ -164,93 +139,71 @@ class VideoDecoder {
164139
int streamIndex,
165140
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
166141

167-
torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
168-
169-
// ---- SINGLE FRAME SEEK AND DECODING API ----
170-
// Places the cursor at the first frame on or after the position in seconds.
171-
// Calling getNextFrameNoDemuxInternal() will return the first frame at
172-
// or after this position.
173-
void setCursorPtsInSeconds(double seconds);
174-
175-
// This structure ensures we always keep the streamIndex and AVFrame together
176-
// Note that AVFrame itself doesn't retain the streamIndex.
177-
struct AVFrameStream {
178-
// The actual decoded output as a unique pointer to an AVFrame.
179-
UniqueAVFrame avFrame;
180-
// The stream index of the decoded frame.
181-
int streamIndex;
182-
};
142+
// --------------------------------------------------------------------------
143+
// DECODING AND SEEKING APIs
144+
// --------------------------------------------------------------------------
183145

146+
// All public decoding entry points return either a FrameOutput or a
147+
// FrameBatchOutput.
148+
// They are the equivalent of the user-facing Frame and FrameBatch classes in
149+
// Python. They contain RGB decoded frames along with some associated data
150+
// like PTS and duration.
184151
struct FrameOutput {
185-
// The actual decoded output as a Tensor.
186-
torch::Tensor data;
187-
// The stream index of the decoded frame. Used to distinguish
188-
// between streams that are of the same type.
152+
torch::Tensor data; // 3D: of shape CHW or HWC.
189153
int streamIndex;
190-
// The presentation timestamp of the decoded frame in seconds.
191154
double ptsSeconds;
192-
// The duration of the decoded frame in seconds.
193155
double durationSeconds;
194156
};
195157

196158
struct FrameBatchOutput {
197-
torch::Tensor data;
198-
torch::Tensor ptsSeconds;
199-
torch::Tensor durationSeconds;
159+
torch::Tensor data; // 4D: of shape NCHW or NHWC.
160+
torch::Tensor ptsSeconds; // 1D of shape (N,)
161+
torch::Tensor durationSeconds; // 1D of shape (N,)
200162

201163
explicit FrameBatchOutput(
202164
int64_t numFrames,
203165
const VideoStreamOptions& videoStreamOptions,
204166
const StreamMetadata& streamMetadata);
205167
};
206168

207-
class EndOfFileException : public std::runtime_error {
208-
public:
209-
explicit EndOfFileException(const std::string& msg)
210-
: std::runtime_error(msg) {}
211-
};
169+
// Places the cursor at the first frame on or after the position in seconds.
170+
// Calling getNextFrameNoDemux() will return the first frame at
171+
// or after this position.
172+
void setCursorPtsInSeconds(double seconds);
212173

213174
// Decodes the frame where the current cursor position is. It also advances
214175
// the cursor to the next frame.
215176
FrameOutput getNextFrameNoDemux();
216-
// Decodes the first frame in any added stream that is visible at a given
217-
// timestamp. Frames in the video have a presentation timestamp and a
218-
// duration. For example, if a frame has presentation timestamp of 5.0s and a
219-
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
220-
// i.e. it will be returned when this function is called with seconds=5.0 or
221-
// seconds=5.999, etc.
222-
FrameOutput getFramePlayedAtNoDemux(double seconds);
223177

224178
FrameOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
225-
// This is morally private but needs to be exposed for C++ tests. Once
226-
// getFrameAtIndex supports the preAllocatedOutputTensor parameter, we can
227-
// move it back to private.
228-
FrameOutput getFrameAtIndexInternal(
229-
int streamIndex,
230-
int64_t frameIndex,
231-
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
232179

233180
// Returns frames at the given indices for a given stream as a single stacked
234181
// Tensor.
235182
FrameBatchOutput getFramesAtIndices(
236183
int streamIndex,
237184
const std::vector<int64_t>& frameIndices);
238185

186+
// Returns frames within a given range. The range is defined by [start, stop).
187+
// The values retrieved from the range are: [start, start+step,
188+
// start+(2*step), start+(3*step), ..., stop). The default for step is 1.
189+
FrameBatchOutput
190+
getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
191+
192+
// Decodes the first frame in any added stream that is visible at a given
193+
// timestamp. Frames in the video have a presentation timestamp and a
194+
// duration. For example, if a frame has presentation timestamp of 5.0s and a
195+
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
196+
// i.e. it will be returned when this function is called with seconds=5.0 or
197+
// seconds=5.999, etc.
198+
FrameOutput getFramePlayedAtNoDemux(double seconds);
199+
239200
FrameBatchOutput getFramesPlayedAt(
240201
int streamIndex,
241202
const std::vector<double>& timestamps);
242203

243-
// Returns frames within a given range for a given stream as a single stacked
244-
// Tensor. The range is defined by [start, stop). The values retrieved from
245-
// the range are:
246-
// [start, start+step, start+(2*step), start+(3*step), ..., stop)
247-
// The default for step is 1.
248-
FrameBatchOutput
249-
getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
250-
251-
// Returns frames within a given pts range for a given stream as a single
252-
// stacked tensor. The range is defined by [startSeconds, stopSeconds) with
253-
// respect to the pts values for frames. The returned frames are in pts order.
204+
// Returns frames within a given pts range. The range is defined by
205+
// [startSeconds, stopSeconds) with respect to the pts values for frames. The
206+
// returned frames are in pts order.
254207
//
255208
// Note that while stopSeconds is excluded in the half open range, this really
256209
// only makes a difference when stopSeconds is exactly the pts value for a
@@ -270,11 +223,44 @@ class VideoDecoder {
270223
double startSeconds,
271224
double stopSeconds);
272225

226+
class EndOfFileException : public std::runtime_error {
227+
public:
228+
explicit EndOfFileException(const std::string& msg)
229+
: std::runtime_error(msg) {}
230+
};
231+
273232
// --------------------------------------------------------------------------
274-
// DECODER PERFORMANCE STATISTICS API
233+
// MORALLY PRIVATE APIS
275234
// --------------------------------------------------------------------------
235+
// These are APIs that should be private, but that are effectively exposed for
236+
// practical reasons, typically for testing purposes.
237+
238+
// This struct is needed because AVFrame doesn't retain the streamIndex. Only
239+
// the AVPacket knows its stream. This is what the low-level private decoding
240+
// entry points return. The AVFrameStream is then converted to a FrameOutput
241+
// with convertAVFrameToFrameOutput. It should be private, but is currently
242+
// used by DeviceInterface.
243+
struct AVFrameStream {
244+
// The actual decoded output as a unique pointer to an AVFrame.
245+
// Usually, this is a YUV frame. It'll be converted to RGB in
246+
// convertAVFrameToFrameOutput.
247+
UniqueAVFrame avFrame;
248+
// The stream index of the decoded frame.
249+
int streamIndex;
250+
};
251+
252+
// Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
253+
// can move getFrameAtIndexInternal back to private.
254+
FrameOutput getFrameAtIndexInternal(
255+
int streamIndex,
256+
int64_t frameIndex,
257+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
258+
259+
// Exposed for _test_frame_pts_equality, which is used to test non-regression
260+
// of pts resolution (64 to 32 bit floats)
261+
double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
276262

277-
// Only exposed for performance testing.
263+
// Exposed for performance testing.
278264
struct DecodeStats {
279265
int64_t numSeeksAttempted = 0;
280266
int64_t numSeeksDone = 0;
@@ -288,9 +274,11 @@ class VideoDecoder {
288274
DecodeStats getDecodeStats() const;
289275
void resetDecodeStats();
290276

291-
double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
292-
293277
private:
278+
explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
279+
explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
280+
torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
281+
294282
struct FrameInfo {
295283
int64_t pts = 0;
296284
// The value of this default is important: the last frame's nextPts will be
@@ -401,6 +389,7 @@ class VideoDecoder {
401389
const enum AVColorSpace colorspace);
402390

403391
void maybeSeekToBeforeDesiredPts();
392+
404393
AVFrameStream getAVFrameUsingFilterFunction(
405394
std::function<bool(int, AVFrame*)>);
406395
// Once we create a decoder, we can update the metadata with the codec context.

0 commit comments

Comments
 (0)