@@ -563,13 +563,14 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     if (packet->flags & AV_PKT_FLAG_DISCARD) {
       continue;
     }
-    auto& stream = containerMetadata_.streams[streamIndex];
-    stream.minPtsFromScan =
-        std::min(stream.minPtsFromScan.value_or(INT64_MAX), packet->pts);
-    stream.maxPtsFromScan = std::max(
-        stream.maxPtsFromScan.value_or(INT64_MIN),
+    auto& streamMetadata = containerMetadata_.streams[streamIndex];
+    streamMetadata.minPtsFromScan = std::min(
+        streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
+    streamMetadata.maxPtsFromScan = std::max(
+        streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
         packet->pts + packet->duration);
-    stream.numFramesFromScan = stream.numFramesFromScan.value_or(0) + 1;
+    streamMetadata.numFramesFromScan =
+        streamMetadata.numFramesFromScan.value_or(0) + 1;

     FrameInfo frameInfo;
     frameInfo.pts = packet->pts;
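The scan above keeps running min/max pts statistics in `std::optional` fields so that "no frame seen yet" is distinguishable from a real value. A minimal sketch of that pattern, using a hypothetical simplified `StreamMetadata` (the real struct has more fields):

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Hypothetical, trimmed-down stand-in for the decoder's StreamMetadata.
struct StreamMetadata {
  std::optional<int64_t> minPtsFromScan;
  std::optional<int64_t> maxPtsFromScan;
  std::optional<int64_t> numFramesFromScan;
};

void recordPacket(StreamMetadata& m, int64_t pts, int64_t duration) {
  // value_or(INT64_MAX) guarantees the first packet always wins the min
  // comparison; likewise value_or(INT64_MIN) for the max.
  m.minPtsFromScan = std::min(m.minPtsFromScan.value_or(INT64_MAX), pts);
  // The max tracks the packet's *end* time, pts + duration.
  m.maxPtsFromScan =
      std::max(m.maxPtsFromScan.value_or(INT64_MIN), pts + duration);
  m.numFramesFromScan = m.numFramesFromScan.value_or(0) + 1;
}
```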
@@ -579,16 +580,17 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     }
     streams_[streamIndex].allFrames.push_back(frameInfo);
   }
-  for (size_t i = 0; i < containerMetadata_.streams.size(); ++i) {
-    auto& streamMetadata = containerMetadata_.streams[i];
-    auto stream = formatContext_->streams[i];
+  for (size_t streamIndex = 0; streamIndex < containerMetadata_.streams.size();
+       ++streamIndex) {
+    auto& streamMetadata = containerMetadata_.streams[streamIndex];
+    auto avStream = formatContext_->streams[streamIndex];
     if (streamMetadata.minPtsFromScan.has_value()) {
       streamMetadata.minPtsSecondsFromScan =
-          *streamMetadata.minPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.minPtsFromScan * av_q2d(avStream->time_base);
     }
     if (streamMetadata.maxPtsFromScan.has_value()) {
       streamMetadata.maxPtsSecondsFromScan =
-          *streamMetadata.maxPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base);
     }
   }
   int ffmepgStatus =
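The seconds conversion above multiplies a pts value by `av_q2d(time_base)`, FFmpeg's rational-to-double helper. A self-contained illustration (the wrapper name here is ours, not the decoder's):

```cpp
extern "C" {
#include <libavutil/rational.h> // av_q2d, AVRational
}
#include <cstdint>

// av_q2d turns the rational time base into a double; with the common
// 1/90000 MPEG-TS time base, pts 90000 corresponds to 1.0 second.
double ptsToSecondsSketch(int64_t pts, AVRational timeBase) {
  return pts * av_q2d(timeBase);
}
```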
@@ -598,23 +600,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
598600 " Could not seek file to pts=0: " +
599601 getFFMPEGErrorStringFromErrorCode (ffmepgStatus));
600602 }
601- for (auto & [streamIndex, stream ] : streams_) {
603+ for (auto & [streamIndex, streamInfo ] : streams_) {
602604 std::sort (
603- stream .keyFrames .begin (),
604- stream .keyFrames .end (),
605+ streamInfo .keyFrames .begin (),
606+ streamInfo .keyFrames .end (),
605607 [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
606608 return frameInfo1.pts < frameInfo2.pts ;
607609 });
608610 std::sort (
609- stream .allFrames .begin (),
610- stream .allFrames .end (),
611+ streamInfo .allFrames .begin (),
612+ streamInfo .allFrames .end (),
611613 [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
612614 return frameInfo1.pts < frameInfo2.pts ;
613615 });
614616
615- for (size_t i = 0 ; i < stream .allFrames .size (); ++i) {
616- if (i + 1 < stream .allFrames .size ()) {
617- stream .allFrames [i].nextPts = stream .allFrames [i + 1 ].pts ;
617+ for (size_t i = 0 ; i < streamInfo .allFrames .size (); ++i) {
618+ if (i + 1 < streamInfo .allFrames .size ()) {
619+ streamInfo .allFrames [i].nextPts = streamInfo .allFrames [i + 1 ].pts ;
618620 }
619621 }
620622 }
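The sort-then-link pass above orders each stream's frames by pts and then points every frame at its successor's pts. A compact sketch of the same idea, assuming a simplified `FrameInfo`:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical, trimmed-down FrameInfo; the real struct stores more.
struct FrameInfo {
  int64_t pts = 0;
  int64_t nextPts = INT64_MAX; // sentinel: the last frame has no successor
};

void sortAndLink(std::vector<FrameInfo>& frames) {
  std::sort(
      frames.begin(), frames.end(), [](const FrameInfo& a, const FrameInfo& b) {
        return a.pts < b.pts;
      });
  // Once sorted, a frame's display interval ends where the next one begins.
  for (size_t i = 0; i + 1 < frames.size(); ++i) {
    frames[i].nextPts = frames[i + 1].pts;
  }
}
```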
@@ -870,11 +872,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
-  output.pts = frame->pts;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
-  output.duration = getDuration(frame);
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
@@ -931,86 +931,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }

   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};

-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }

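The dispatch above hinges on caching `prevFrameContext`: the swscale or filtergraph conversion state is rebuilt only when the frame's geometry or pixel format changes mid-stream, and reused otherwise. A standalone sketch of that caching pattern, with hypothetical simplified types in place of the real `SwsContext`/`AVFilterGraph` state:

```cpp
#include <memory>
#include <optional>

// Hypothetical, simplified version of the decoder's DecodedFrameContext.
struct DecodedFrameContext {
  int inWidth, inHeight, inFormat, outWidth, outHeight;
  bool operator!=(const DecodedFrameContext& other) const {
    return inWidth != other.inWidth || inHeight != other.inHeight ||
        inFormat != other.inFormat || outWidth != other.outWidth ||
        outHeight != other.outHeight;
  }
};

struct ConversionState {}; // stands in for SwsContext / AVFilterGraph

struct StreamState {
  std::unique_ptr<ConversionState> conversion;
  std::optional<DecodedFrameContext> prevFrameContext;
};

// Rebuild the conversion objects only when the context changed; reusing
// them across frames is the performance-critical common case.
void ensureConversion(StreamState& s, const DecodedFrameContext& current) {
  if (!s.conversion || !s.prevFrameContext || *s.prevFrameContext != current) {
    s.conversion = std::make_unique<ConversionState>();
    s.prevFrameContext = current;
  }
}
```

Deferring construction until the first decoded frame also means the decoder never has to trust the (often unreliable) resolution and pixel-format metadata in the container header.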