@@ -162,12 +162,11 @@ void VideoDecoder::initializeDecoder() {
162
162
av_q2d (avStream->time_base ) * avStream->duration ;
163
163
}
164
164
165
- double fps = av_q2d (avStream->r_frame_rate );
166
- if (fps > 0 ) {
167
- streamMetadata.averageFps = fps;
168
- }
169
-
170
165
if (avStream->codecpar ->codec_type == AVMEDIA_TYPE_VIDEO) {
166
+ double fps = av_q2d (avStream->r_frame_rate );
167
+ if (fps > 0 ) {
168
+ streamMetadata.averageFps = fps;
169
+ }
171
170
containerMetadata_.numVideoStreams ++;
172
171
} else if (avStream->codecpar ->codec_type == AVMEDIA_TYPE_AUDIO) {
173
172
containerMetadata_.numAudioStreams ++;
@@ -340,7 +339,7 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
340
339
}
341
340
342
341
torch::Tensor VideoDecoder::getKeyFrameIndices () {
343
- validateActiveStream ();
342
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
344
343
validateScannedAllStreams (" getKeyFrameIndices" );
345
344
346
345
const std::vector<FrameInfo>& keyFrames =
@@ -409,84 +408,76 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
409
408
}
410
409
}
411
410
412
- void VideoDecoder::addVideoStream (
411
+ void VideoDecoder::addStream (
413
412
int streamIndex,
414
- const VideoStreamOptions& videoStreamOptions) {
413
+ AVMediaType mediaType,
414
+ const torch::Device& device,
415
+ std::optional<int > ffmpegThreadCount) {
415
416
TORCH_CHECK (
416
417
activeStreamIndex_ == NO_ACTIVE_STREAM,
417
418
" Can only add one single stream." );
419
+ TORCH_CHECK (
420
+ mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO,
421
+ " Can only add video or audio streams." );
418
422
TORCH_CHECK (formatContext_.get () != nullptr );
419
423
420
424
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
421
425
422
426
activeStreamIndex_ = av_find_best_stream (
423
- formatContext_.get (), AVMEDIA_TYPE_VIDEO, streamIndex, -1 , &avCodec, 0 );
427
+ formatContext_.get (), mediaType, streamIndex, -1 , &avCodec, 0 );
428
+
424
429
if (activeStreamIndex_ < 0 ) {
425
- throw std::invalid_argument (" No valid stream found in input file." );
430
+ throw std::invalid_argument (
431
+ " No valid stream found in input file. Is " +
432
+ std::to_string (streamIndex) + " of the desired media type?" );
426
433
}
434
+
427
435
TORCH_CHECK (avCodec != nullptr );
428
436
429
437
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
430
438
streamInfo.streamIndex = activeStreamIndex_;
431
439
streamInfo.timeBase = formatContext_->streams [activeStreamIndex_]->time_base ;
432
440
streamInfo.stream = formatContext_->streams [activeStreamIndex_];
441
+ streamInfo.avMediaType = mediaType;
433
442
434
- if (streamInfo.stream ->codecpar ->codec_type != AVMEDIA_TYPE_VIDEO) {
435
- throw std::invalid_argument (
436
- " Stream with index " + std::to_string (activeStreamIndex_) +
437
- " is not a video stream." );
438
- }
439
-
440
- if (videoStreamOptions.device .type () == torch::kCUDA ) {
443
+ // This should never happen, checking just to be safe.
444
+ TORCH_CHECK (
445
+ streamInfo.stream ->codecpar ->codec_type == mediaType,
446
+ " FFmpeg found stream with index " ,
447
+ activeStreamIndex_,
448
+ " which is of the wrong media type." );
449
+
450
+ // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
451
+ // addStream() which is supposed to be generic
452
+ if (mediaType == AVMEDIA_TYPE_VIDEO && device.type () == torch::kCUDA ) {
441
453
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream (
442
- findCudaCodec (
443
- videoStreamOptions.device , streamInfo.stream ->codecpar ->codec_id )
454
+ findCudaCodec (device, streamInfo.stream ->codecpar ->codec_id )
444
455
.value_or (avCodec));
445
456
}
446
457
447
- StreamMetadata& streamMetadata =
448
- containerMetadata_.allStreamMetadata [activeStreamIndex_];
449
- if (seekMode_ == SeekMode::approximate &&
450
- !streamMetadata.averageFps .has_value ()) {
451
- throw std::runtime_error (
452
- " Seek mode is approximate, but stream " +
453
- std::to_string (activeStreamIndex_) +
454
- " does not have an average fps in its metadata." );
455
- }
456
-
457
458
AVCodecContext* codecContext = avcodec_alloc_context3 (avCodec);
458
459
TORCH_CHECK (codecContext != nullptr );
459
- codecContext->thread_count = videoStreamOptions.ffmpegThreadCount .value_or (0 );
460
460
streamInfo.codecContext .reset (codecContext);
461
461
462
462
int retVal = avcodec_parameters_to_context (
463
463
streamInfo.codecContext .get (), streamInfo.stream ->codecpar );
464
464
TORCH_CHECK_EQ (retVal, AVSUCCESS);
465
465
466
- if (videoStreamOptions.device .type () == torch::kCPU ) {
467
- // No more initialization needed for CPU.
468
- } else if (videoStreamOptions.device .type () == torch::kCUDA ) {
469
- initializeContextOnCuda (videoStreamOptions.device , codecContext);
470
- } else {
471
- TORCH_CHECK (
472
- false , " Invalid device type: " + videoStreamOptions.device .str ());
466
+ streamInfo.codecContext ->thread_count = ffmpegThreadCount.value_or (0 );
467
+
468
+ // TODO_CODE_QUALITY same as above.
469
+ if (mediaType == AVMEDIA_TYPE_VIDEO && device.type () == torch::kCUDA ) {
470
+ initializeContextOnCuda (device, codecContext);
473
471
}
474
- streamInfo.videoStreamOptions = videoStreamOptions;
475
472
476
473
retVal = avcodec_open2 (streamInfo.codecContext .get (), avCodec, nullptr );
477
474
if (retVal < AVSUCCESS) {
478
475
throw std::invalid_argument (getFFMPEGErrorStringFromErrorCode (retVal));
479
476
}
480
477
481
478
codecContext->time_base = streamInfo.stream ->time_base ;
482
-
483
- containerMetadata_.allStreamMetadata [activeStreamIndex_].width =
484
- codecContext->width ;
485
- containerMetadata_.allStreamMetadata [activeStreamIndex_].height =
486
- codecContext->height ;
487
- auto codedId = codecContext->codec_id ;
488
479
containerMetadata_.allStreamMetadata [activeStreamIndex_].codecName =
489
- std::string (avcodec_get_name (codedId ));
480
+ std::string (avcodec_get_name (codecContext-> codec_id ));
490
481
491
482
// We will only need packets from the active stream, so we tell FFmpeg to
492
483
// discard packets from the other streams. Note that av_read_frame() may still
@@ -497,6 +488,38 @@ void VideoDecoder::addVideoStream(
497
488
formatContext_->streams [i]->discard = AVDISCARD_ALL;
498
489
}
499
490
}
491
+ }
492
+
493
+ void VideoDecoder::addVideoStream (
494
+ int streamIndex,
495
+ const VideoStreamOptions& videoStreamOptions) {
496
+ TORCH_CHECK (
497
+ videoStreamOptions.device .type () == torch::kCPU ||
498
+ videoStreamOptions.device .type () == torch::kCUDA ,
499
+ " Invalid device type: " + videoStreamOptions.device .str ());
500
+
501
+ addStream (
502
+ streamIndex,
503
+ AVMEDIA_TYPE_VIDEO,
504
+ videoStreamOptions.device ,
505
+ videoStreamOptions.ffmpegThreadCount );
506
+
507
+ auto & streamMetadata =
508
+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
509
+
510
+ if (seekMode_ == SeekMode::approximate &&
511
+ !streamMetadata.averageFps .has_value ()) {
512
+ throw std::runtime_error (
513
+ " Seek mode is approximate, but stream " +
514
+ std::to_string (activeStreamIndex_) +
515
+ " does not have an average fps in its metadata." );
516
+ }
517
+
518
+ auto & streamInfo = streamInfos_[activeStreamIndex_];
519
+ streamInfo.videoStreamOptions = videoStreamOptions;
520
+
521
+ streamMetadata.width = streamInfo.codecContext ->width ;
522
+ streamMetadata.height = streamInfo.codecContext ->height ;
500
523
501
524
// By default, we want to use swscale for color conversion because it is
502
525
// faster. However, it has width requirements, so we may need to fall back
@@ -505,7 +528,7 @@ void VideoDecoder::addVideoStream(
505
528
// swscale's width requirements to be violated. We don't expose the ability to
506
529
// choose color conversion library publicly; we only use this ability
507
530
// internally.
508
- int width = videoStreamOptions.width .value_or (codecContext->width );
531
+ int width = videoStreamOptions.width .value_or (streamInfo. codecContext ->width );
509
532
510
533
// swscale requires widths to be multiples of 32:
511
534
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
@@ -518,6 +541,21 @@ void VideoDecoder::addVideoStream(
518
541
videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
519
542
}
520
543
544
+ void VideoDecoder::addAudioStream (int streamIndex) {
545
+ TORCH_CHECK (
546
+ seekMode_ == SeekMode::approximate,
547
+ " seek_mode must be 'approximate' for audio streams." );
548
+
549
+ addStream (streamIndex, AVMEDIA_TYPE_AUDIO);
550
+
551
+ auto & streamInfo = streamInfos_[activeStreamIndex_];
552
+ auto & streamMetadata =
553
+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
554
+ streamMetadata.sampleRate =
555
+ static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
556
+ streamMetadata.numChannels = getNumChannels (streamInfo.codecContext );
557
+ }
558
+
521
559
// --------------------------------------------------------------------------
522
560
// HIGH-LEVEL DECODING ENTRY-POINTS
523
561
// --------------------------------------------------------------------------
@@ -546,7 +584,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
546
584
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal (
547
585
int64_t frameIndex,
548
586
std::optional<torch::Tensor> preAllocatedOutputTensor) {
549
- validateActiveStream ();
587
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
550
588
551
589
const auto & streamInfo = streamInfos_[activeStreamIndex_];
552
590
const auto & streamMetadata =
@@ -560,7 +598,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
560
598
561
599
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices (
562
600
const std::vector<int64_t >& frameIndices) {
563
- validateActiveStream ();
601
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
564
602
565
603
auto indicesAreSorted =
566
604
std::is_sorted (frameIndices.begin (), frameIndices.end ());
@@ -619,7 +657,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
619
657
620
658
VideoDecoder::FrameBatchOutput
621
659
VideoDecoder::getFramesInRange (int64_t start, int64_t stop, int64_t step) {
622
- validateActiveStream ();
660
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
623
661
624
662
const auto & streamMetadata =
625
663
containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -690,7 +728,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
690
728
691
729
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt (
692
730
const std::vector<double >& timestamps) {
693
- validateActiveStream ();
731
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
694
732
695
733
const auto & streamMetadata =
696
734
containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -721,7 +759,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
721
759
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange (
722
760
double startSeconds,
723
761
double stopSeconds) {
724
- validateActiveStream ();
762
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
725
763
726
764
const auto & streamMetadata =
727
765
containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -860,7 +898,7 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
860
898
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
861
899
// the comment of canWeAvoidSeeking() for details.
862
900
void VideoDecoder::maybeSeekToBeforeDesiredPts () {
863
- validateActiveStream ();
901
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
864
902
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
865
903
866
904
int64_t desiredPts =
@@ -907,7 +945,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
907
945
908
946
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
909
947
std::function<bool (AVFrame*)> filterFunction) {
910
- validateActiveStream ();
948
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
911
949
912
950
resetDecodeStats ();
913
951
@@ -1587,7 +1625,8 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
1587
1625
// VALIDATION UTILS
1588
1626
// --------------------------------------------------------------------------
1589
1627
1590
- void VideoDecoder::validateActiveStream () {
1628
+ void VideoDecoder::validateActiveStream (
1629
+ std::optional<AVMediaType> avMediaType) {
1591
1630
auto errorMsg =
1592
1631
" Provided stream index=" + std::to_string (activeStreamIndex_) +
1593
1632
" was not previously added." ;
@@ -1601,6 +1640,14 @@ void VideoDecoder::validateActiveStream() {
1601
1640
" Invalid stream index=" + std::to_string (activeStreamIndex_) +
1602
1641
" ; valid indices are in the range [0, " +
1603
1642
std::to_string (allStreamMetadataSize) + " )." );
1643
+
1644
+ if (avMediaType.has_value ()) {
1645
+ TORCH_CHECK (
1646
+ streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value (),
1647
+ " The method you called isn't supported. " ,
1648
+ " If you're seeing this error, you are probably trying to call an " ,
1649
+ " unsupported method on an audio stream." );
1650
+ }
1604
1651
}
1605
1652
1606
1653
void VideoDecoder::validateScannedAllStreams (const std::string& msg) {
@@ -1648,7 +1695,7 @@ void VideoDecoder::resetDecodeStats() {
1648
1695
}
1649
1696
1650
1697
double VideoDecoder::getPtsSecondsForFrame (int64_t frameIndex) {
1651
- validateActiveStream ();
1698
+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
1652
1699
validateScannedAllStreams (" getPtsSecondsForFrame" );
1653
1700
1654
1701
const auto & streamInfo = streamInfos_[activeStreamIndex_];
0 commit comments