@@ -418,8 +418,9 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
418418 }
419419}
420420
// Makes one stream of the requested media type the decoder's active stream
// and opens a codec context for it.
//
// NOTE(review): this listing is a diff with elided hunks; `-` lines are the
// previous addVideoStreamDecoder code, `+` lines are the new addStream code.
//
// Parameters:
//   streamIndex        - forwarded to av_find_best_stream as the wanted stream.
//   mediaType          - AVMediaType the selected stream must have.
//   videoStreamOptions - supplies the FFmpeg thread count for every media
//                        type; its device is only consulted when mediaType is
//                        AVMEDIA_TYPE_VIDEO.
//
// Throws std::invalid_argument when av_find_best_stream returns a negative
// index or when avcodec_open2 fails. Other invariants are enforced with
// TORCH_CHECK / TORCH_CHECK_EQ.
421- void VideoDecoder::addVideoStreamDecoder (
421+ void VideoDecoder::addStream (
422422 int streamIndex,
423+ AVMediaType mediaType,
423424 const VideoStreamOptions& videoStreamOptions) {
// Precondition checked here: no stream is active yet (NO_ACTIVE_STREAM).
424425 TORCH_CHECK (
425426 activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -429,30 +430,37 @@ void VideoDecoder::addVideoStreamDecoder(
429430 AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
430431
// The return value becomes the active stream index; a negative value is
// treated as failure right below.
431432 activeStreamIndex_ = av_find_best_stream (
432- formatContext_.get (), AVMEDIA_TYPE_VIDEO, streamIndex, -1 , &avCodec, 0 );
433+ formatContext_.get (), mediaType, streamIndex, -1 , &avCodec, 0 );
434+
// NOTE(review): the new message interpolates only the bare streamIndex
// number, with no word such as 'stream' in front of it - confirm the wording.
433435 if (activeStreamIndex_ < 0 ) {
434- throw std::invalid_argument (" No valid stream found in input file." );
436+ throw std::invalid_argument (
437+ " No valid stream found in input file. Is " +
438+ std::to_string (streamIndex) + " of the desired media type?" );
435439 }
440+
// avCodec was passed by address to av_find_best_stream and must be set by now.
436441 TORCH_CHECK (avCodec != nullptr );
437442
// Record per-stream state under the index FFmpeg resolved.
438443 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
439444 streamInfo.streamIndex = activeStreamIndex_;
440445 streamInfo.timeBase = formatContext_->streams [activeStreamIndex_]->time_base ;
441446 streamInfo.stream = formatContext_->streams [activeStreamIndex_];
447+ streamInfo.avMediaType = mediaType;
442448
443- if (streamInfo.stream ->codecpar ->codec_type != AVMEDIA_TYPE_VIDEO) {
444- throw std::invalid_argument (
445- " Stream with index " + std::to_string (activeStreamIndex_) +
446- " is not a video stream." );
447- }
449+ // This should never happen, checking just to be safe.
450+ TORCH_CHECK (
451+ streamInfo.stream ->codecpar ->codec_type == mediaType,
452+ " FFmpeg found stream with index " , activeStreamIndex_, " which is of the wrong media type." );
448453
449- if (videoStreamOptions.device .type () == torch::kCUDA ) {
454+
// Only video streams on a CUDA device may swap in a CUDA codec; value_or
// keeps the codec found above when findCudaCodec yields nothing.
455+ if (mediaType == AVMEDIA_TYPE_VIDEO &&
456+ videoStreamOptions.device .type () == torch::kCUDA ) {
450457 avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream (
451458 findCudaCodec (
452459 videoStreamOptions.device , streamInfo.stream ->codecpar ->codec_id )
453460 .value_or (avCodec));
454461 }
455462
463+ // TODO figure out whether this should be VIDEO only
456464 StreamMetadata& streamMetadata =
457465 containerMetadata_.allStreamMetadata [activeStreamIndex_];
458466 if (seekMode_ == SeekMode::approximate &&
@@ -465,37 +473,34 @@ void VideoDecoder::addVideoStreamDecoder(
465473
// The raw context allocated here is handed to streamInfo.codecContext via
// reset() a few lines below.
466474 AVCodecContext* codecContext = avcodec_alloc_context3 (avCodec);
467475 TORCH_CHECK (codecContext != nullptr );
// value_or(0): 0 is used when no thread count was requested.
468- codecContext->thread_count = videoStreamOptions.ffmpegThreadCount .value_or (0 );
476+ codecContext->thread_count =
477+ videoStreamOptions.ffmpegThreadCount .value_or (0 ); // TODO VIDEO ONLY?
469478 streamInfo.codecContext .reset (codecContext);
470479
// Copy the stream's codec parameters into the freshly allocated context.
471480 int retVal = avcodec_parameters_to_context (
472481 streamInfo.codecContext .get (), streamInfo.stream ->codecpar );
473482 TORCH_CHECK_EQ (retVal, AVSUCCESS);
474483
475- if (videoStreamOptions.device .type () == torch::kCPU ) {
476- // No more initialization needed for CPU.
477- } else if (videoStreamOptions.device .type () == torch::kCUDA ) {
478- initializeContextOnCuda (videoStreamOptions.device , codecContext);
479- } else {
480- TORCH_CHECK (
481- false , " Invalid device type: " + videoStreamOptions.device .str ());
// In the new code, device-specific setup and storing videoStreamOptions on
// streamInfo happen for video streams only.
484+ if (mediaType == AVMEDIA_TYPE_VIDEO) {
485+ if (videoStreamOptions.device .type () == torch::kCPU ) {
486+ // No more initialization needed for CPU.
487+ } else if (videoStreamOptions.device .type () == torch::kCUDA ) {
488+ initializeContextOnCuda (videoStreamOptions.device , codecContext);
489+ } else {
490+ TORCH_CHECK (
491+ false , " Invalid device type: " + videoStreamOptions.device .str ());
492+ }
493+ streamInfo.videoStreamOptions = videoStreamOptions;
482494 }
483- streamInfo.videoStreamOptions = videoStreamOptions;
484495
// A failure to open the codec is reported as std::invalid_argument carrying
// the string from getFFMPEGErrorStringFromErrorCode.
485496 retVal = avcodec_open2 (streamInfo.codecContext .get (), avCodec, nullptr );
486497 if (retVal < AVSUCCESS) {
487498 throw std::invalid_argument (getFFMPEGErrorStringFromErrorCode (retVal));
488499 }
489500
490501 codecContext->time_base = streamInfo.stream ->time_base ;
491-
492- containerMetadata_.allStreamMetadata [activeStreamIndex_].width =
493- codecContext->width ;
494- containerMetadata_.allStreamMetadata [activeStreamIndex_].height =
495- codecContext->height ;
496- auto codedId = codecContext->codec_id ;
// codecName is derived from the opened context's codec_id; the width/height
// metadata removed above is now written by addVideoStream.
497502 containerMetadata_.allStreamMetadata [activeStreamIndex_].codecName =
498- std::string (avcodec_get_name (codedId ));
503+ std::string (avcodec_get_name (codecContext-> codec_id ));
499504
500505 // We will only need packets from the active stream, so we tell FFmpeg to
501506 // discard packets from the other streams. Note that av_read_frame() may still
@@ -506,6 +511,18 @@ void VideoDecoder::addVideoStreamDecoder(
506511 formatContext_->streams [i]->discard = AVDISCARD_ALL;
507512 }
508513 }
514+ }
515+
// Adds the stream `streamIndex` as the active video stream.
//
// Delegates the media-type-agnostic setup to addStream with
// AVMEDIA_TYPE_VIDEO, then records the video-only width/height metadata from
// the codec context and resolves which color conversion library to use.
//
// Parameters:
//   streamIndex        - forwarded unchanged to addStream.
//   videoStreamOptions - forwarded to addStream; its optional width and
//                        colorConversionLibrary are also read here.
//
// NOTE(review): parts of this body are elided by the diff hunks (the
// computation of defaultLibrary and the assignment target of the final
// expression are not visible) - confirm against the full file.
516+ void VideoDecoder::addVideoStream (
517+ int streamIndex,
518+ const VideoStreamOptions& videoStreamOptions) {
519+ addStream (streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions);
520+
// addStream assigns activeStreamIndex_ (and throws if it is negative), so it
// is used as the index from here on.
521+ auto & streamInfo = streamInfos_[activeStreamIndex_];
522+ containerMetadata_.allStreamMetadata [activeStreamIndex_].width =
523+ streamInfo.codecContext ->width ;
524+ containerMetadata_.allStreamMetadata [activeStreamIndex_].height =
525+ streamInfo.codecContext ->height ;
509526
510527 // By default, we want to use swscale for color conversion because it is
511528 // faster. However, it has width requirements, so we may need to fall back
@@ -514,7 +531,7 @@ void VideoDecoder::addVideoStreamDecoder(
514531 // swscale's width requirements to be violated. We don't expose the ability to
515532 // choose color conversion library publicly; we only use this ability
516533 // internally.
// A user-requested width takes precedence over the codec context's width.
517- int width = videoStreamOptions.width .value_or (codecContext->width );
534+ int width = videoStreamOptions.width .value_or (streamInfo. codecContext ->width );
518535
519536 // swscale requires widths to be multiples of 32:
520537 // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
@@ -527,6 +544,10 @@ void VideoDecoder::addVideoStreamDecoder(
// An explicitly requested library wins; otherwise defaultLibrary is used.
527544 videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
528545}
529546
// Adds the stream `streamIndex` as the active audio stream by delegating to
// addStream with AVMEDIA_TYPE_AUDIO.
//
// NOTE(review): the addStream definition in this file takes a third
// videoStreamOptions parameter; this two-argument call presumably relies on a
// default argument in the class declaration - confirm in the header.
547+ void VideoDecoder::addAudioStream (int streamIndex) {
548+ addStream (streamIndex, AVMEDIA_TYPE_AUDIO);
549+ }
550+
530551// --------------------------------------------------------------------------
531552// HIGH-LEVEL DECODING ENTRY-POINTS
532553// --------------------------------------------------------------------------
@@ -1051,7 +1072,6 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10511072 AVFrame* avFrame = avFrameStream.avFrame .get ();
10521073 frameOutput.streamIndex = streamIndex;
10531074 auto & streamInfo = streamInfos_[streamIndex];
1054- TORCH_CHECK (streamInfo.stream ->codecpar ->codec_type == AVMEDIA_TYPE_VIDEO);
10551075 frameOutput.ptsSeconds = ptsToSeconds (
10561076 avFrame->pts , formatContext_->streams [streamIndex]->time_base );
10571077 frameOutput.durationSeconds = ptsToSeconds (
0 commit comments