@@ -281,7 +281,9 @@ class VideoDecoder {
281281 void resetDecodeStats ();
282282
283283 private:
284- torch::Tensor maybePermuteHWC2CHW (int streamIndex, torch::Tensor& hwcTensor);
284+ // --------------------------------------------------------------------------
285+ // STREAMINFO AND ASSOCIATED STRUCTS
286+ // --------------------------------------------------------------------------
285287
286288 struct FrameInfo {
287289 int64_t pts = 0 ;
@@ -309,73 +311,114 @@ class VideoDecoder {
309311 bool operator !=(const DecodedFrameContext&);
310312 };
311313
312- // Stores information for each stream.
313314 struct StreamInfo {
314315 int streamIndex = -1 ;
315316 AVStream* stream = nullptr ;
316317 AVRational timeBase = {};
317318 UniqueAVCodecContext codecContext;
318- // The current position of the cursor in the stream.
319+
320+ // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
321+ // called.
322+ std::vector<FrameInfo> keyFrames;
323+ std::vector<FrameInfo> allFrames;
324+
325+ // The current position of the cursor in the stream, and associated frame
326+ // duration.
319327 int64_t currentPts = 0 ;
320328 int64_t currentDuration = 0 ;
321329 // The desired position of the cursor in the stream. We send frames >=
322330 // this pts to the user when they request a frame.
323- // We update this field if the user requested a seek.
331+ // We update this field if the user requested a seek. This typically
332+ // corresponds to the decoder's desiredPts_ attribute.
324333 int64_t discardFramesBeforePts = INT64_MIN;
325334 VideoStreamOptions videoStreamOptions;
326- // The filter state associated with this stream (for video streams). The
327- // actual graph will be nullptr for inactive streams.
335+
336+ // color-conversion fields. Only one of FilterGraphContext and
337+ // UniqueSwsContext should be non-null.
328338 FilterGraphContext filterGraphContext;
329339 ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
330- std::vector<FrameInfo> keyFrames;
331- std::vector<FrameInfo> allFrames;
332- DecodedFrameContext prevFrameContext;
333340 UniqueSwsContext swsContext;
341+
342+ // Used to know whether a new FilterGraphContext or UniqueSwsContext should
343+ // be created before decoding a new frame.
344+ DecodedFrameContext prevFrameContext;
334345 };
335346
336- // Returns the key frame index of the presentation timestamp using FFMPEG's
337- // index. Note that this index may be truncated for some files.
338- int getKeyFrameIndexForPtsUsingEncoderIndex (AVStream* stream, int64_t pts)
339- const ;
340- // Returns the key frame index of the presentation timestamp using our index.
341- // We build this index by scanning the file in buildKeyFrameIndex().
342- int getKeyFrameIndexForPtsUsingScannedIndex (
343- const std::vector<VideoDecoder::FrameInfo>& keyFrames,
344- int64_t pts) const ;
345- int getKeyFrameIndexForPts (const StreamInfo& stream, int64_t pts) const ;
347+ // --------------------------------------------------------------------------
348+ // INITIALIZERS
349+ // --------------------------------------------------------------------------
350+
351+ void initializeDecoder ();
352+ void updateMetadataWithCodecContext (
353+ int streamIndex,
354+ AVCodecContext* codecContext);
355+
356+ // --------------------------------------------------------------------------
357+ // DECODING APIS AND RELATED UTILS
358+ // --------------------------------------------------------------------------
359+
346360 bool canWeAvoidSeekingForStream (
347361 const StreamInfo& stream,
348362 int64_t currentPts,
349363 int64_t targetPts) const ;
350- // Returns the "best" stream index for a given media type. The "best" is
351- // determined by various heuristics in FFMPEG.
352- // See
353- // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
354- // for more details about the heuristics.
355- int getBestStreamIndex (AVMediaType mediaType);
356- void initializeDecoder ();
357- void validateUserProvidedStreamIndex (int streamIndex);
358- void validateScannedAllStreams (const std::string& msg);
359- void validateFrameIndex (
360- const StreamMetadata& streamMetadata,
361- int64_t frameIndex);
362364
363- // Creates and initializes a filter graph for a stream. The filter graph can
364- // do rescaling and color conversion.
365+ void maybeSeekToBeforeDesiredPts ();
366+
367+ AVFrameStream decodeAVFrame (
368+ std::function<bool (int , AVFrame*)> filterFunction);
369+
370+ FrameOutput getNextFrameNoDemuxInternal (
371+ std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt );
372+
373+ torch::Tensor maybePermuteHWC2CHW (int streamIndex, torch::Tensor& hwcTensor);
374+
375+ FrameOutput convertAVFrameToFrameOutput (
376+ AVFrameStream& avFrameStream,
377+ std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt );
378+
379+ void convertAVFrameToFrameOutputOnCPU (
380+ AVFrameStream& avFrameStream,
381+ FrameOutput& frameOutput,
382+ std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt );
383+
384+ torch::Tensor convertAVFrameToTensorUsingFilterGraph (
385+ int streamIndex,
386+ const AVFrame* avFrame);
387+
388+ int convertAVFrameToTensorUsingSwsScale (
389+ int streamIndex,
390+ const AVFrame* avFrame,
391+ torch::Tensor& outputTensor);
392+
393+ // --------------------------------------------------------------------------
394+ // COLOR CONVERSION LIBRARIES HANDLERS CREATION
395+ // --------------------------------------------------------------------------
396+
365397 void createFilterGraph (
366398 StreamInfo& streamInfo,
367399 int expectedOutputHeight,
368400 int expectedOutputWidth);
369401
370- int64_t getNumFrames (const StreamMetadata& streamMetadata);
402+ void createSwsContext (
403+ StreamInfo& streamInfo,
404+ const DecodedFrameContext& frameContext,
405+ const enum AVColorSpace colorspace);
371406
372- int64_t getPts (
373- const StreamInfo& streamInfo,
374- const StreamMetadata& streamMetadata,
375- int64_t frameIndex);
407+ // --------------------------------------------------------------------------
408+ // PTS <-> INDEX CONVERSIONS
409+ // --------------------------------------------------------------------------
376410
377- double getMinSeconds (const StreamMetadata& streamMetadata);
378- double getMaxSeconds (const StreamMetadata& streamMetadata);
411+ int getKeyFrameIndexForPts (const StreamInfo& stream, int64_t pts) const ;
412+
413+ // Returns the key frame index of the presentation timestamp using our index.
414+ // We build this index by scanning the file in
415+ // scanFileAndUpdateMetadataAndIndex
416+ int getKeyFrameIndexForPtsUsingScannedIndex (
417+ const std::vector<VideoDecoder::FrameInfo>& keyFrames,
418+ int64_t pts) const ;
419+ // Return key frame index, from FFmpeg. Potentially less accurate
420+ int getKeyFrameIndexForPtsUsingEncoderIndex (AVStream* stream, int64_t pts)
421+ const ;
379422
380423 int64_t secondsToIndexLowerBound (
381424 double seconds,
@@ -387,40 +430,43 @@ class VideoDecoder {
387430 const StreamInfo& streamInfo,
388431 const StreamMetadata& streamMetadata);
389432
390- void createSwsContext (
391- StreamInfo& streamInfo,
392- const DecodedFrameContext& frameContext,
393- const enum AVColorSpace colorspace);
394-
395- void maybeSeekToBeforeDesiredPts ();
433+ int64_t getPts (
434+ const StreamInfo& streamInfo,
435+ const StreamMetadata& streamMetadata,
436+ int64_t frameIndex);
396437
397- AVFrameStream decodeAVFrame (
398- std::function<bool (int , AVFrame*)> filterFunction);
438+ // --------------------------------------------------------------------------
439+ // STREAM AND METADATA APIS
440+ // --------------------------------------------------------------------------
399441
400- // Once we create a decoder can update the metadata with the codec context.
401- // For example, for video streams, we can add the height and width of the
402- // decoded stream.
403- void updateMetadataWithCodecContext (
404- int streamIndex,
405- AVCodecContext* codecContext);
406442 void populateVideoMetadataFromStreamIndex (int streamIndex);
407- torch::Tensor convertAVFrameToTensorUsingFilterGraph (
408- int streamIndex,
409- const AVFrame* avFrame);
410- int convertAVFrameToTensorUsingSwsScale (
411- int streamIndex,
412- const AVFrame* avFrame,
413- torch::Tensor& outputTensor);
414- FrameOutput convertAVFrameToFrameOutput (
415- AVFrameStream& avFrameStream,
416- std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt );
417- void convertAVFrameToFrameOutputOnCPU (
418- AVFrameStream& avFrameStream,
419- FrameOutput& frameOutput,
420- std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt );
421443
422- FrameOutput getNextFrameNoDemuxInternal (
423- std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt );
444+ // Returns the "best" stream index for a given media type. The "best" is
445+ // determined by various heuristics in FFMPEG.
446+ // See
447+ // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
448+ // for more details about the heuristics.
449+ // Returns the key frame index of the presentation timestamp using FFMPEG's
450+ // index. Note that this index may be truncated for some files.
451+ int getBestStreamIndex (AVMediaType mediaType);
452+
453+ int64_t getNumFrames (const StreamMetadata& streamMetadata);
454+ double getMinSeconds (const StreamMetadata& streamMetadata);
455+ double getMaxSeconds (const StreamMetadata& streamMetadata);
456+
457+ // --------------------------------------------------------------------------
458+ // VALIDATION UTILS
459+ // --------------------------------------------------------------------------
460+
461+ void validateUserProvidedStreamIndex (int streamIndex);
462+ void validateScannedAllStreams (const std::string& msg);
463+ void validateFrameIndex (
464+ const StreamMetadata& streamMetadata,
465+ int64_t frameIndex);
466+
467+ // --------------------------------------------------------------------------
468+ // ATTRIBUTES
469+ // --------------------------------------------------------------------------
424470
425471 SeekMode seekMode_;
426472 ContainerMetadata containerMetadata_;
@@ -432,7 +478,6 @@ class VideoDecoder {
432478 // Set when the user wants to seek and stores the desired pts that the user
433479 // wants to seek to.
434480 std::optional<double > desiredPtsSeconds_;
435-
436481 // Stores various internal decoding stats.
437482 DecodeStats decodeStats_;
438483 // Stores the AVIOContext for the input buffer.
0 commit comments