Commit 3253c8f

Reorganize private APIs in videoDecoder.h (#479)
1 parent eaf3dd3

File tree: 1 file changed, +116 -71 lines

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 116 additions & 71 deletions
@@ -281,7 +281,9 @@ class VideoDecoder {
   void resetDecodeStats();

  private:
-  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+  // --------------------------------------------------------------------------
+  // STREAMINFO AND ASSOCIATED STRUCTS
+  // --------------------------------------------------------------------------

   struct FrameInfo {
     int64_t pts = 0;
@@ -309,73 +311,114 @@ class VideoDecoder {
     bool operator!=(const DecodedFrameContext&);
   };

-  // Stores information for each stream.
   struct StreamInfo {
     int streamIndex = -1;
     AVStream* stream = nullptr;
     AVRational timeBase = {};
     UniqueAVCodecContext codecContext;
-    // The current position of the cursor in the stream.
+
+    // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
+    // called.
+    std::vector<FrameInfo> keyFrames;
+    std::vector<FrameInfo> allFrames;
+
+    // The current position of the cursor in the stream, and associated frame
+    // duration.
     int64_t currentPts = 0;
     int64_t currentDuration = 0;
     // The desired position of the cursor in the stream. We send frames >=
     // this pts to the user when they request a frame.
-    // We update this field if the user requested a seek.
+    // We update this field if the user requested a seek. This typically
+    // corresponds to the decoder's desiredPts_ attribute.
     int64_t discardFramesBeforePts = INT64_MIN;
     VideoStreamOptions videoStreamOptions;
-    // The filter state associated with this stream (for video streams). The
-    // actual graph will be nullptr for inactive streams.
+
+    // color-conversion fields. Only one of FilterGraphContext and
+    // UniqueSwsContext should be non-null.
     FilterGraphContext filterGraphContext;
     ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
-    std::vector<FrameInfo> keyFrames;
-    std::vector<FrameInfo> allFrames;
-    DecodedFrameContext prevFrameContext;
     UniqueSwsContext swsContext;
+
+    // Used to know whether a new FilterGraphContext or UniqueSwsContext should
+    // be created before decoding a new frame.
+    DecodedFrameContext prevFrameContext;
   };

-  // Returns the key frame index of the presentation timestamp using FFMPEG's
-  // index. Note that this index may be truncated for some files.
-  int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts)
-      const;
-  // Returns the key frame index of the presentation timestamp using our index.
-  // We build this index by scanning the file in buildKeyFrameIndex().
-  int getKeyFrameIndexForPtsUsingScannedIndex(
-      const std::vector<VideoDecoder::FrameInfo>& keyFrames,
-      int64_t pts) const;
-  int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const;
+  // --------------------------------------------------------------------------
+  // INITIALIZERS
+  // --------------------------------------------------------------------------
+
+  void initializeDecoder();
+  void updateMetadataWithCodecContext(
+      int streamIndex,
+      AVCodecContext* codecContext);
+
+  // --------------------------------------------------------------------------
+  // DECODING APIS AND RELATED UTILS
+  // --------------------------------------------------------------------------
+
   bool canWeAvoidSeekingForStream(
       const StreamInfo& stream,
       int64_t currentPts,
       int64_t targetPts) const;
-  // Returns the "best" stream index for a given media type. The "best" is
-  // determined by various heuristics in FFMPEG.
-  // See
-  // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
-  // for more details about the heuristics.
-  int getBestStreamIndex(AVMediaType mediaType);
-  void initializeDecoder();
-  void validateUserProvidedStreamIndex(int streamIndex);
-  void validateScannedAllStreams(const std::string& msg);
-  void validateFrameIndex(
-      const StreamMetadata& streamMetadata,
-      int64_t frameIndex);

-  // Creates and initializes a filter graph for a stream. The filter graph can
-  // do rescaling and color conversion.
+  void maybeSeekToBeforeDesiredPts();
+
+  AVFrameStream decodeAVFrame(
+      std::function<bool(int, AVFrame*)> filterFunction);
+
+  FrameOutput getNextFrameNoDemuxInternal(
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+
+  FrameOutput convertAVFrameToFrameOutput(
+      AVFrameStream& avFrameStream,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  void convertAVFrameToFrameOutputOnCPU(
+      AVFrameStream& avFrameStream,
+      FrameOutput& frameOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
+      int streamIndex,
+      const AVFrame* avFrame);
+
+  int convertAVFrameToTensorUsingSwsScale(
+      int streamIndex,
+      const AVFrame* avFrame,
+      torch::Tensor& outputTensor);
+
+  // --------------------------------------------------------------------------
+  // COLOR CONVERSION LIBRARIES HANDLERS CREATION
+  // --------------------------------------------------------------------------
+
   void createFilterGraph(
       StreamInfo& streamInfo,
       int expectedOutputHeight,
       int expectedOutputWidth);

-  int64_t getNumFrames(const StreamMetadata& streamMetadata);
+  void createSwsContext(
+      StreamInfo& streamInfo,
+      const DecodedFrameContext& frameContext,
+      const enum AVColorSpace colorspace);

-  int64_t getPts(
-      const StreamInfo& streamInfo,
-      const StreamMetadata& streamMetadata,
-      int64_t frameIndex);
+  // --------------------------------------------------------------------------
+  // PTS <-> INDEX CONVERSIONS
+  // --------------------------------------------------------------------------

-  double getMinSeconds(const StreamMetadata& streamMetadata);
-  double getMaxSeconds(const StreamMetadata& streamMetadata);
+  int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const;
+
+  // Returns the key frame index of the presentation timestamp using our index.
+  // We build this index by scanning the file in
+  // scanFileAndUpdateMetadataAndIndex
+  int getKeyFrameIndexForPtsUsingScannedIndex(
+      const std::vector<VideoDecoder::FrameInfo>& keyFrames,
+      int64_t pts) const;
+  // Return key frame index, from FFmpeg. Potentially less accurate
+  int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts)
+      const;

   int64_t secondsToIndexLowerBound(
       double seconds,
@@ -387,40 +430,43 @@ class VideoDecoder {
       const StreamInfo& streamInfo,
       const StreamMetadata& streamMetadata);

-  void createSwsContext(
-      StreamInfo& streamInfo,
-      const DecodedFrameContext& frameContext,
-      const enum AVColorSpace colorspace);
-
-  void maybeSeekToBeforeDesiredPts();
+  int64_t getPts(
+      const StreamInfo& streamInfo,
+      const StreamMetadata& streamMetadata,
+      int64_t frameIndex);

-  AVFrameStream decodeAVFrame(
-      std::function<bool(int, AVFrame*)> filterFunction);
+  // --------------------------------------------------------------------------
+  // STREAM AND METADATA APIS
+  // --------------------------------------------------------------------------

-  // Once we create a decoder can update the metadata with the codec context.
-  // For example, for video streams, we can add the height and width of the
-  // decoded stream.
-  void updateMetadataWithCodecContext(
-      int streamIndex,
-      AVCodecContext* codecContext);
   void populateVideoMetadataFromStreamIndex(int streamIndex);
-  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
-      int streamIndex,
-      const AVFrame* avFrame);
-  int convertAVFrameToTensorUsingSwsScale(
-      int streamIndex,
-      const AVFrame* avFrame,
-      torch::Tensor& outputTensor);
-  FrameOutput convertAVFrameToFrameOutput(
-      AVFrameStream& avFrameStream,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
-  void convertAVFrameToFrameOutputOnCPU(
-      AVFrameStream& avFrameStream,
-      FrameOutput& frameOutput,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

-  FrameOutput getNextFrameNoDemuxInternal(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  // Returns the "best" stream index for a given media type. The "best" is
+  // determined by various heuristics in FFMPEG.
+  // See
+  // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
+  // for more details about the heuristics.
+  // Returns the key frame index of the presentation timestamp using FFMPEG's
+  // index. Note that this index may be truncated for some files.
+  int getBestStreamIndex(AVMediaType mediaType);
+
+  int64_t getNumFrames(const StreamMetadata& streamMetadata);
+  double getMinSeconds(const StreamMetadata& streamMetadata);
+  double getMaxSeconds(const StreamMetadata& streamMetadata);
+
+  // --------------------------------------------------------------------------
+  // VALIDATION UTILS
+  // --------------------------------------------------------------------------
+
+  void validateUserProvidedStreamIndex(int streamIndex);
+  void validateScannedAllStreams(const std::string& msg);
+  void validateFrameIndex(
+      const StreamMetadata& streamMetadata,
+      int64_t frameIndex);
+
+  // --------------------------------------------------------------------------
+  // ATTRIBUTES
+  // --------------------------------------------------------------------------

   SeekMode seekMode_;
   ContainerMetadata containerMetadata_;
@@ -432,7 +478,6 @@ class VideoDecoder {
   // Set when the user wants to seek and stores the desired pts that the user
   // wants to seek to.
   std::optional<double> desiredPtsSeconds_;
-
   // Stores various internal decoding stats.
   DecodeStats decodeStats_;
   // Stores the AVIOContext for the input buffer.
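
The new StreamInfo comments document two invariants: only one of FilterGraphContext and UniqueSwsContext should be non-null for a given stream, and prevFrameContext is compared against the current frame's context to decide whether a new handler must be created before decoding. The sketch below is a hypothetical illustration of that gating based solely on those comments and the declarations above; it is not the actual implementation, and the helper name maybeRecreateColorConversionHandler as well as the frameContext/height/width/colorspace parameters are assumed for the example:

// Hypothetical sketch only -- illustrates the documented invariants, not the
// actual torchcodec implementation. Written as if it were a private member of
// VideoDecoder (a matching declaration in the private section is assumed),
// since it uses private types and helpers.
void VideoDecoder::maybeRecreateColorConversionHandler(
    StreamInfo& streamInfo,
    const DecodedFrameContext& frameContext,
    int expectedOutputHeight,
    int expectedOutputWidth,
    AVColorSpace colorspace) {
  // prevFrameContext remembers the context of the previously decoded frame;
  // only rebuild a handler when the context changed (e.g. new dimensions).
  if (streamInfo.prevFrameContext != frameContext) {
    if (streamInfo.colorConversionLibrary == FILTERGRAPH) {
      // Filtergraph path: the swscale context is left unset.
      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
    } else {
      // Swscale path: the filter graph is left unset.
      createSwsContext(streamInfo, frameContext, colorspace);
    }
    streamInfo.prevFrameContext = frameContext;
  }
}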