@@ -871,7 +871,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
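For reference, the unchanged `ptsToSeconds` call above converts a presentation timestamp from stream time-base units to seconds. A minimal sketch of what such a helper typically looks like in FFmpeg-based code (the real TorchCodec helper may differ) is:

```cpp
#include <cstdint>

extern "C" {
#include <libavutil/rational.h>
}

// Hypothetical sketch: an AVStream's pts values are expressed in units of its
// time_base (an AVRational fraction), so seconds = pts * num / den, which is
// exactly what av_q2d() computes.
double ptsToSeconds(int64_t pts, AVRational timeBase) {
  return static_cast<double>(pts) * av_q2d(timeBase);
}
// Example: pts = 3003 in a 1/30000 time base yields 0.1001 seconds.
```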
@@ -932,86 +932,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }

   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};

-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }

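The caching logic above keys the swscale/filtergraph state on a `DecodedFrameContext` value. As a rough sketch of that pattern (hypothetical field names; the real struct in TorchCodec may differ), the idea is to bundle every input that influences colorspace conversion and compare it against the previous frame's context:

```cpp
extern "C" {
#include <libavutil/pixfmt.h>
}

// Hypothetical sketch of the frame-context caching pattern: if any of these
// fields changes mid-stream (e.g. a resolution change), the comparison fails
// and the conversion objects are rebuilt; otherwise they are reused.
struct DecodedFrameContext {
  int decodedWidth;
  int decodedHeight;
  AVPixelFormat decodedFormat;
  int expectedWidth;
  int expectedHeight;

  bool operator==(const DecodedFrameContext& other) const {
    return decodedWidth == other.decodedWidth &&
        decodedHeight == other.decodedHeight &&
        decodedFormat == other.decodedFormat &&
        expectedWidth == other.expectedWidth &&
        expectedHeight == other.expectedHeight;
  }
  bool operator!=(const DecodedFrameContext& other) const {
    return !(*this == other);
  }
};
```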
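On the SWSCALE path, the `resultHeight` check works because `sws_scale()` returns the height of the output slice it wrote. A hypothetical sketch of what `convertFrameToTensorUsingSwsScale` could look like under that assumption (the actual TorchCodec implementation may differ), writing into an HWC uint8 RGB tensor:

```cpp
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}
#include <torch/torch.h>

// Hypothetical sketch: scale/convert the decoded AVFrame directly into the
// tensor's memory. sws_scale() returns the number of output rows written,
// which the caller compares against expectedOutputHeight.
int convertFrameToTensorUsingSwsScaleSketch(
    SwsContext* swsContext,
    const AVFrame* frame,
    torch::Tensor& outputTensor) {
  uint8_t* dstPlanes[4] = {outputTensor.data_ptr<uint8_t>()};
  // For packed RGB24 output, the row stride is width * 3 bytes.
  int dstStrides[4] = {static_cast<int>(outputTensor.size(1)) * 3};
  return sws_scale(
      swsContext,
      frame->data,
      frame->linesize,
      0,
      frame->height,
      dstPlanes,
      dstStrides);
}
```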