@@ -583,9 +583,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
583583VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal (
584584 std::optional<torch::Tensor> preAllocatedOutputTensor) {
585585 validateActiveStream ();
586- AVFrameStream avFrameStream = decodeAVFrame (
587- [this ](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
588- return convertAVFrameToFrameOutput (avFrameStream , preAllocatedOutputTensor);
586+ UniqueAVFrame avFrame = decodeAVFrame (
587+ [this ](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
588+ return convertAVFrameToFrameOutput (avFrame , preAllocatedOutputTensor);
589589}
590590
591591VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex (int64_t frameIndex) {
@@ -715,8 +715,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
715715 }
716716
717717 setCursorPtsInSeconds (seconds);
718- AVFrameStream avFrameStream =
719- decodeAVFrame ([seconds, this ](AVFrame* avFrame) {
718+ UniqueAVFrame avFrame =
719+ decodeAVFrame ([seconds, this ](const UniqueAVFrame& avFrame) {
720720 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
721721 double frameStartTime = ptsToSeconds (avFrame->pts , streamInfo.timeBase );
722722 double frameEndTime = ptsToSeconds (
@@ -735,7 +735,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
735735 });
736736
737737 // Convert the frame to tensor.
738- FrameOutput frameOutput = convertAVFrameToFrameOutput (avFrameStream );
738+ FrameOutput frameOutput = convertAVFrameToFrameOutput (avFrame );
739739 frameOutput.data = maybePermuteHWC2CHW (frameOutput.data );
740740 return frameOutput;
741741}
@@ -891,14 +891,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
891891 auto finished = false ;
892892 while (!finished) {
893893 try {
894- AVFrameStream avFrameStream = decodeAVFrame ([startPts](AVFrame* avFrame) {
895- return startPts < avFrame->pts + getDuration (avFrame);
896- });
897- // TODO: it's not great that we are getting a FrameOutput, which is
898- // intended for videos. We should consider bypassing
899- // convertAVFrameToFrameOutput and directly call
900- // convertAudioAVFrameToFrameOutputOnCPU.
901- auto frameOutput = convertAVFrameToFrameOutput (avFrameStream);
894+ UniqueAVFrame avFrame =
895+ decodeAVFrame ([startPts](const UniqueAVFrame& avFrame) {
896+ return startPts < avFrame->pts + getDuration (avFrame);
897+ });
898+ auto frameOutput = convertAVFrameToFrameOutput (avFrame);
902899 firstFramePtsSeconds =
903900 std::min (firstFramePtsSeconds, frameOutput.ptsSeconds );
904901 frames.push_back (frameOutput.data );
@@ -1035,8 +1032,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
10351032// LOW-LEVEL DECODING
10361033// --------------------------------------------------------------------------
10371034
1038- VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
1039- std::function<bool (AVFrame* )> filterFunction) {
1035+ UniqueAVFrame VideoDecoder::decodeAVFrame (
1036+ std::function<bool (const UniqueAVFrame& )> filterFunction) {
10401037 validateActiveStream ();
10411038
10421039 resetDecodeStats ();
@@ -1064,7 +1061,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
10641061
10651062 decodeStats_.numFramesReceivedByDecoder ++;
10661063 // Is this the kind of frame we're looking for?
1067- if (status == AVSUCCESS && filterFunction (avFrame. get () )) {
1064+ if (status == AVSUCCESS && filterFunction (avFrame)) {
10681065 // Yes, this is the frame we'll return; break out of the decoding loop.
10691066 break ;
10701067 } else if (status == AVSUCCESS) {
@@ -1150,37 +1147,35 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
11501147 streamInfo.lastDecodedAvFramePts = avFrame->pts ;
11511148 streamInfo.lastDecodedAvFrameDuration = getDuration (avFrame);
11521149
1153- return AVFrameStream ( std::move ( avFrame), activeStreamIndex_) ;
1150+ return avFrame;
11541151}
11551152
11561153// --------------------------------------------------------------------------
11571154// AVFRAME <-> FRAME OUTPUT CONVERSION
11581155// --------------------------------------------------------------------------
11591156
11601157VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput (
1161- VideoDecoder::AVFrameStream& avFrameStream ,
1158+ UniqueAVFrame& avFrame ,
11621159 std::optional<torch::Tensor> preAllocatedOutputTensor) {
11631160 // Convert the frame to tensor.
11641161 FrameOutput frameOutput;
1165- int streamIndex = avFrameStream.streamIndex ;
1166- AVFrame* avFrame = avFrameStream.avFrame .get ();
1167- frameOutput.streamIndex = streamIndex;
1168- auto & streamInfo = streamInfos_[streamIndex];
1162+ auto & streamInfo = streamInfos_[activeStreamIndex_];
11691163 frameOutput.ptsSeconds = ptsToSeconds (
1170- avFrame->pts , formatContext_->streams [streamIndex ]->time_base );
1164+ avFrame->pts , formatContext_->streams [activeStreamIndex_ ]->time_base );
11711165 frameOutput.durationSeconds = ptsToSeconds (
1172- getDuration (avFrame), formatContext_->streams [streamIndex]->time_base );
1166+ getDuration (avFrame),
1167+ formatContext_->streams [activeStreamIndex_]->time_base );
11731168 if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
11741169 convertAudioAVFrameToFrameOutputOnCPU (
1175- avFrameStream , frameOutput, preAllocatedOutputTensor);
1170+ avFrame , frameOutput, preAllocatedOutputTensor);
11761171 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
11771172 convertAVFrameToFrameOutputOnCPU (
1178- avFrameStream , frameOutput, preAllocatedOutputTensor);
1173+ avFrame , frameOutput, preAllocatedOutputTensor);
11791174 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCUDA ) {
11801175 convertAVFrameToFrameOutputOnCuda (
11811176 streamInfo.videoStreamOptions .device ,
11821177 streamInfo.videoStreamOptions ,
1183- avFrameStream ,
1178+ avFrame ,
11841179 frameOutput,
11851180 preAllocatedOutputTensor);
11861181 } else {
@@ -1201,14 +1196,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12011196// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
12021197// `dimension_order` parameter. It's up to callers to re-shape it if needed.
12031198void VideoDecoder::convertAVFrameToFrameOutputOnCPU (
1204- VideoDecoder::AVFrameStream& avFrameStream ,
1199+ UniqueAVFrame& avFrame ,
12051200 FrameOutput& frameOutput,
12061201 std::optional<torch::Tensor> preAllocatedOutputTensor) {
1207- AVFrame* avFrame = avFrameStream.avFrame .get ();
12081202 auto & streamInfo = streamInfos_[activeStreamIndex_];
12091203
12101204 auto frameDims = getHeightAndWidthFromOptionsOrAVFrame (
1211- streamInfo.videoStreamOptions , * avFrame);
1205+ streamInfo.videoStreamOptions , avFrame);
12121206 int expectedOutputHeight = frameDims.height ;
12131207 int expectedOutputWidth = frameDims.width ;
12141208
@@ -1302,7 +1296,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
13021296}
13031297
13041298int VideoDecoder::convertAVFrameToTensorUsingSwsScale (
1305- const AVFrame* avFrame,
1299+ const UniqueAVFrame& avFrame,
13061300 torch::Tensor& outputTensor) {
13071301 StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
13081302 SwsContext* swsContext = activeStreamInfo.swsContext .get ();
@@ -1322,11 +1316,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
13221316}
13231317
13241318torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph (
1325- const AVFrame* avFrame) {
1319+ const UniqueAVFrame& avFrame) {
13261320 FilterGraphContext& filterGraphContext =
13271321 streamInfos_[activeStreamIndex_].filterGraphContext ;
13281322 int status =
1329- av_buffersrc_write_frame (filterGraphContext.sourceContext , avFrame);
1323+ av_buffersrc_write_frame (filterGraphContext.sourceContext , avFrame. get () );
13301324 if (status < AVSUCCESS) {
13311325 throw std::runtime_error (" Failed to add frame to buffer source context" );
13321326 }
@@ -1350,25 +1344,25 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13501344}
13511345
13521346void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1353- VideoDecoder::AVFrameStream& avFrameStream ,
1347+ UniqueAVFrame& srcAVFrame ,
13541348 FrameOutput& frameOutput,
13551349 std::optional<torch::Tensor> preAllocatedOutputTensor) {
13561350 TORCH_CHECK (
13571351 !preAllocatedOutputTensor.has_value (),
13581352 " pre-allocated audio tensor not supported yet." );
13591353
13601354 AVSampleFormat sourceSampleFormat =
1361- static_cast <AVSampleFormat>(avFrameStream. avFrame ->format );
1355+ static_cast <AVSampleFormat>(srcAVFrame ->format );
13621356 AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13631357
13641358 UniqueAVFrame convertedAVFrame;
13651359 if (sourceSampleFormat != desiredSampleFormat) {
13661360 convertedAVFrame = convertAudioAVFrameSampleFormat (
1367- avFrameStream. avFrame , sourceSampleFormat, desiredSampleFormat);
1361+ srcAVFrame , sourceSampleFormat, desiredSampleFormat);
13681362 }
13691363 const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
13701364 ? convertedAVFrame
1371- : avFrameStream. avFrame ;
1365+ : srcAVFrame ;
13721366
13731367 AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
13741368 TORCH_CHECK (
@@ -1944,10 +1938,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
19441938
19451939FrameDims getHeightAndWidthFromOptionsOrAVFrame (
19461940 const VideoDecoder::VideoStreamOptions& videoStreamOptions,
1947- const AVFrame & avFrame) {
1941+ const UniqueAVFrame & avFrame) {
19481942 return FrameDims (
1949- videoStreamOptions.height .value_or (avFrame. height ),
1950- videoStreamOptions.width .value_or (avFrame. width ));
1943+ videoStreamOptions.height .value_or (avFrame-> height ),
1944+ videoStreamOptions.width .value_or (avFrame-> width ));
19511945}
19521946
19531947} // namespace facebook::torchcodec
0 commit comments