@@ -869,7 +869,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
@@ -930,86 +929,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
 
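Note on the caching pattern above: the comment block explains that colorspace conversion objects are created lazily and rebuilt only when the decoded frame's properties change, with DecodedFrameContext acting as the cache key. A minimal sketch of that pattern follows; the field names mirror the brace-initializer in the diff, but the definitions here (including the stub context type and helper) are illustrative assumptions, not the actual code from VideoDecoder.h.

#include <memory>
#include <tuple>

// Hypothetical stand-in for the real DecodedFrameContext; field names follow
// the brace-initializer in the diff above.
struct DecodedFrameContext {
  int decodedWidth = 0;
  int decodedHeight = 0;
  int decodedFormat = -1;  // stands in for enum AVPixelFormat
  int expectedWidth = 0;
  int expectedHeight = 0;

  bool operator!=(const DecodedFrameContext& other) const {
    return std::tie(decodedWidth, decodedHeight, decodedFormat,
                    expectedWidth, expectedHeight) !=
        std::tie(other.decodedWidth, other.decodedHeight, other.decodedFormat,
                 other.expectedWidth, other.expectedHeight);
  }
};

struct SwsContextStub {};  // stands in for FFmpeg's SwsContext

struct StreamState {
  std::unique_ptr<SwsContextStub> swsContext;
  DecodedFrameContext prevFrameContext;
};

// Rebuild the conversion object only when the frame context changed (e.g. the
// resolution changed mid-stream); otherwise reuse the cached one.
void ensureSwsContext(StreamState& state, const DecodedFrameContext& current) {
  if (!state.swsContext || state.prevFrameContext != current) {
    state.swsContext = std::make_unique<SwsContextStub>();
    state.prevFrameContext = current;
  }
}

The key property is that the inequality compares every field that feeds the conversion setup, so a mid-stream resolution or pixel-format change forces a rebuild while consecutive identical frames hit the cache.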
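The two branches also honor the optional pre-allocated output tensor differently: swscale can scale directly into the caller's buffer, hence the value_or fallback allocation, while filtergraph produces its own tensor that must be copied into the pre-allocated one afterwards. A sketch of that contract, assuming LibTorch; makeEmptyHWCTensor is a hypothetical stand-in for the allocateEmptyHWCTensor helper called in the diff.

#include <optional>

#include <torch/torch.h>

// Uninitialized height x width x 3 uint8 tensor on the given device, in the
// HWC layout the decoder emits.
torch::Tensor makeEmptyHWCTensor(
    int height, int width, const torch::Device& device) {
  return torch::empty(
      {height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(device));
}

// swscale-style: write into the caller's buffer when provided, otherwise into
// a fresh allocation (mirrors preAllocatedOutputTensor.value_or(...)).
torch::Tensor outputForSwsScale(
    const std::optional<torch::Tensor>& preAllocated, int height, int width) {
  return preAllocated.value_or(makeEmptyHWCTensor(height, width, torch::kCPU));
}

// filtergraph-style: the filter pipeline allocates its own tensor, so copy
// into the caller's buffer to honor the pre-allocation contract. The caller
// is expected to have validated that the shapes match.
torch::Tensor outputForFilterGraph(
    torch::Tensor produced, std::optional<torch::Tensor> preAllocated) {
  if (preAllocated.has_value()) {
    preAllocated.value().copy_(produced);
    return preAllocated.value();
  }
  return produced;
}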