Skip to content

Commit 2ef91b7

Browse files
committed
Only allow get_frames_played_in_range
1 parent c987f9c commit 2ef91b7

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
606606
}
607607

608608
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
609+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
609610
auto frameOutput = getFrameAtIndexInternal(frameIndex);
610611
frameOutput.data = maybePermuteOutputTensor(frameOutput.data);
611612
return frameOutput;
@@ -614,8 +615,6 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
614615
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
615616
int64_t frameIndex,
616617
std::optional<torch::Tensor> preAllocatedOutputTensor) {
617-
validateActiveStream();
618-
619618
const auto& streamInfo = streamInfos_[activeStreamIndex_];
620619
const auto& streamMetadata =
621620
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -628,7 +627,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
628627

629628
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
630629
const std::vector<int64_t>& frameIndices) {
631-
validateActiveStream();
630+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
632631

633632
auto indicesAreSorted =
634633
std::is_sorted(frameIndices.begin(), frameIndices.end());
@@ -685,7 +684,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
685684

686685
VideoDecoder::FrameBatchOutput
687686
VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
688-
validateActiveStream();
687+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
689688

690689
const auto& streamMetadata =
691690
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -714,6 +713,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
714713
}
715714

716715
VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
716+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
717717
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
718718
double frameStartTime =
719719
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -754,7 +754,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
754754

755755
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
756756
const std::vector<double>& timestamps) {
757-
validateActiveStream();
757+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
758758

759759
const auto& streamMetadata =
760760
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -1845,7 +1845,8 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
18451845
// VALIDATION UTILS
18461846
// --------------------------------------------------------------------------
18471847

1848-
void VideoDecoder::validateActiveStream() {
1848+
void VideoDecoder::validateActiveStream(
1849+
std::optional<AVMediaType> avMediaType) {
18491850
auto errorMsg =
18501851
"Provided stream index=" + std::to_string(activeStreamIndex_) +
18511852
" was not previously added.";
@@ -1859,6 +1860,12 @@ void VideoDecoder::validateActiveStream() {
18591860
"Invalid stream index=" + std::to_string(activeStreamIndex_) +
18601861
"; valid indices are in the range [0, " +
18611862
std::to_string(allStreamMetadataSize) + ").");
1863+
1864+
if (avMediaType.has_value()) {
1865+
TORCH_CHECK(
1866+
streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value(),
1867+
"The method you called doesn't support the media type (audio or video)");
1868+
}
18621869
}
18631870

18641871
void VideoDecoder::validateScannedAllStreams(const std::string& msg) {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ class VideoDecoder {
459459
// VALIDATION UTILS
460460
// --------------------------------------------------------------------------
461461

462-
void validateActiveStream();
462+
void validateActiveStream(
463+
std::optional<AVMediaType> avMediaType = std::nullopt);
463464
void validateScannedAllStreams(const std::string& msg);
464465
void validateFrameIndex(
465466
const StreamMetadata& streamMetadata,

test/decoders/test_video_decoder_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import os
8+
from functools import partial
89

910
os.environ["TORCH_LOGS"] = "output_code"
1011
import json
@@ -18,6 +19,7 @@
1819
from torchcodec.decoders._core import (
1920
_add_video_stream,
2021
_test_frame_pts_equality,
22+
add_audio_stream,
2123
add_video_stream,
2224
create_from_bytes,
2325
create_from_file,
@@ -618,6 +620,24 @@ def test_cuda_decoder(self):
618620
duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3
619621
)
620622

623+
@pytest.mark.parametrize(
624+
"method",
625+
(
626+
partial(get_frame_at_index, frame_index=4),
627+
partial(get_frames_at_indices, frame_indices=[4, 5]),
628+
partial(get_frames_in_range, start=4, stop=5),
629+
partial(get_frame_at_pts, seconds=2),
630+
partial(get_frames_by_pts, timestamps=[0, 1.5]),
631+
),
632+
)
633+
def test_audio_bad_method(self, method):
634+
decoder = create_from_file(str(NASA_AUDIO.path))
635+
add_audio_stream(decoder)
636+
with pytest.raises(
637+
RuntimeError, match="The method you called doesn't support the media type"
638+
):
639+
method(decoder)
640+
621641

622642
if __name__ == "__main__":
623643
pytest.main()

0 commit comments

Comments
 (0)