@@ -606,6 +606,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
606606}
607607
608608VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex (int64_t frameIndex) {
609+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
609610 auto frameOutput = getFrameAtIndexInternal (frameIndex);
610611 frameOutput.data = maybePermuteOutputTensor (frameOutput.data );
611612 return frameOutput;
@@ -614,8 +615,6 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
614615VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal (
615616 int64_t frameIndex,
616617 std::optional<torch::Tensor> preAllocatedOutputTensor) {
617- validateActiveStream ();
618-
619618 const auto & streamInfo = streamInfos_[activeStreamIndex_];
620619 const auto & streamMetadata =
621620 containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -628,7 +627,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
628627
629628VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices (
630629 const std::vector<int64_t >& frameIndices) {
631- validateActiveStream ();
630+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
632631
633632 auto indicesAreSorted =
634633 std::is_sorted (frameIndices.begin (), frameIndices.end ());
@@ -685,7 +684,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
685684
686685VideoDecoder::FrameBatchOutput
687686VideoDecoder::getFramesInRange (int64_t start, int64_t stop, int64_t step) {
688- validateActiveStream ();
687+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
689688
690689 const auto & streamMetadata =
691690 containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -714,6 +713,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
714713}
715714
716715VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt (double seconds) {
716+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
717717 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
718718 double frameStartTime =
719719 ptsToSeconds (streamInfo.lastDecodedAvFramePts , streamInfo.timeBase );
@@ -754,7 +754,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
754754
755755VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt (
756756 const std::vector<double >& timestamps) {
757- validateActiveStream ();
757+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
758758
759759 const auto & streamMetadata =
760760 containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -1845,7 +1845,8 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
18451845// VALIDATION UTILS
18461846// --------------------------------------------------------------------------
18471847
1848- void VideoDecoder::validateActiveStream () {
1848+ void VideoDecoder::validateActiveStream (
1849+ std::optional<AVMediaType> avMediaType) {
18491850 auto errorMsg =
18501851 " Provided stream index=" + std::to_string (activeStreamIndex_) +
18511852 " was not previously added." ;
@@ -1859,6 +1860,12 @@ void VideoDecoder::validateActiveStream() {
18591860 " Invalid stream index=" + std::to_string (activeStreamIndex_) +
18601861 " ; valid indices are in the range [0, " +
18611862 std::to_string (allStreamMetadataSize) + " )." );
1863+
1864+ if (avMediaType.has_value ()) {
1865+ TORCH_CHECK (
1866+ streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value (),
1867+ " The method you called doesn't support the media type (audio or video)" );
1868+ }
18621869}
18631870
18641871void VideoDecoder::validateScannedAllStreams (const std::string& msg) {
0 commit comments