@@ -600,6 +600,8 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     streamMetadata.maxPtsFromScan = std::max(
         streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
         packet->pts + packet->duration);
+    streamMetadata.numFramesFromScan =
+        streamMetadata.numFramesFromScan.value_or(0) + 1;
 
     // Note that we set the other value in this struct, nextPts, only after
     // we have scanned all packets and sorted by pts.
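
Note on the added lines: the scan folds each packet into running per-stream stats. A minimal self-contained sketch of that accumulation pattern (a hypothetical struct, not torchcodec's API): maxPtsFromScan tracks the end timestamp (pts + duration) of the latest-ending packet, while numFramesFromScan counts one frame per packet.

#include <algorithm>
#include <cstdint>
#include <optional>

struct ScanStats {
  std::optional<int64_t> maxPtsFromScan;
  std::optional<int64_t> numFramesFromScan;

  void onPacket(int64_t pts, int64_t duration) {
    // End timestamp of the latest-ending packet seen so far.
    maxPtsFromScan =
        std::max(maxPtsFromScan.value_or(INT64_MIN), pts + duration);
    // One frame per scanned packet.
    numFramesFromScan = numFramesFromScan.value_or(0) + 1;
  }
};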
@@ -612,19 +614,20 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
 
   // Set all per-stream metadata that requires knowing the content of all
   // packets.
-  for (size_t i = 0; i < containerMetadata_.streams.size(); ++i) {
-    auto& streamMetadata = containerMetadata_.streams[i];
-    auto stream = formatContext_->streams[i];
+  for (size_t streamIndex = 0; streamIndex < containerMetadata_.streams.size();
+       ++streamIndex) {
+    auto& streamMetadata = containerMetadata_.streams[streamIndex];
+    auto avStream = formatContext_->streams[streamIndex];
 
-    streamMetadata.numFramesFromScan = streams_[i].allFrames.size();
+    streamMetadata.numFramesFromScan = streams_[streamIndex].allFrames.size();
 
     if (streamMetadata.minPtsFromScan.has_value()) {
       streamMetadata.minPtsSecondsFromScan =
-          *streamMetadata.minPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.minPtsFromScan * av_q2d(avStream->time_base);
     }
     if (streamMetadata.maxPtsFromScan.has_value()) {
       streamMetadata.maxPtsSecondsFromScan =
-          *streamMetadata.maxPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base);
     }
   }
 
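Since av_q2d(r) simply returns r.num / (double)r.den, the seconds conversion above is a single multiply. A worked example with hypothetical values:

extern "C" {
#include <libavutil/rational.h>
}
#include <cstdint>

// A 90 kHz time base, common for MPEG transport streams.
AVRational timeBase = {1, 90000};
int64_t maxPtsFromScan = 450000;
// 450000 * (1.0 / 90000) = 5.0 seconds.
double maxPtsSecondsFromScan = maxPtsFromScan * av_q2d(timeBase);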
@@ -638,23 +641,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
   }
 
   // Sort all frames by their pts.
-  for (auto& [streamIndex, stream] : streams_) {
+  for (auto& [streamIndex, streamInfo] : streams_) {
     std::sort(
-        stream.keyFrames.begin(),
-        stream.keyFrames.end(),
+        streamInfo.keyFrames.begin(),
+        streamInfo.keyFrames.end(),
         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
           return frameInfo1.pts < frameInfo2.pts;
         });
     std::sort(
-        stream.allFrames.begin(),
-        stream.allFrames.end(),
+        streamInfo.allFrames.begin(),
+        streamInfo.allFrames.end(),
         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
           return frameInfo1.pts < frameInfo2.pts;
         });
 
-    for (size_t i = 0; i < stream.allFrames.size(); ++i) {
-      if (i + 1 < stream.allFrames.size()) {
-        stream.allFrames[i].nextPts = stream.allFrames[i + 1].pts;
+    for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
+      if (i + 1 < streamInfo.allFrames.size()) {
+        streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
       }
     }
   }
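
The scan visits packets in decode order, which differs from presentation order whenever B-frames are present; that is why both frame lists are sorted by pts before each frame is linked to its successor. A condensed sketch of the sort-then-link step (Frame is a hypothetical stand-in for FrameInfo):

#include <algorithm>
#include <cstdint>
#include <vector>

struct Frame {
  int64_t pts = 0;
  int64_t nextPts = 0; // left at its default for the last frame
};

void sortAndLinkNextPts(std::vector<Frame>& allFrames) {
  std::sort(
      allFrames.begin(), allFrames.end(), [](const Frame& a, const Frame& b) {
        return a.pts < b.pts;
      });
  // E.g. decode-order pts {0, 3000, 1000, 2000} becomes {0, 1000, 2000,
  // 3000}; each frame's nextPts is then its successor's pts.
  for (size_t i = 0; i + 1 < allFrames.size(); ++i) {
    allFrames[i].nextPts = allFrames[i + 1].pts;
  }
}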
@@ -911,11 +914,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
-  output.pts = frame->pts;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
-  output.duration = getDuration(frame);
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
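
The removed output.pts and output.duration fields carried raw stream ticks; only the seconds-based fields remain. ptsToSeconds is presumably the obvious av_q2d wrapper; a plausible reconstruction (an assumption, not copied from the file):

extern "C" {
#include <libavutil/rational.h>
}
#include <cstdint>

// Assumed helper shape: scale a tick count by the stream's rational
// time base to get seconds.
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
  return pts * av_q2d(timeBase);
}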
@@ -972,86 +973,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
 
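The dedented comment block in this hunk states the caching policy: swscale/filtergraph state is created lazily and rebuilt only when streamInfo.prevFrameContext != frameContext, i.e. when the decoded geometry, pixel format, or requested output size changes mid-stream. A hedged sketch of a context struct supporting that comparison (field names inferred from the brace initializer in the diff; the comparison operators are assumptions, not the file's actual definition):

extern "C" {
#include <libavutil/pixfmt.h>
}

struct DecodedFrameContext {
  int decodedWidth;
  int decodedHeight;
  AVPixelFormat decodedFormat;
  int expectedWidth;
  int expectedHeight;

  bool operator==(const DecodedFrameContext& other) const {
    return decodedWidth == other.decodedWidth &&
        decodedHeight == other.decodedHeight &&
        decodedFormat == other.decodedFormat &&
        expectedWidth == other.expectedWidth &&
        expectedHeight == other.expectedHeight;
  }
  bool operator!=(const DecodedFrameContext& other) const {
    return !(*this == other);
  }
};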