Skip to content

Commit 3d955c1

Browse files
committed
Add case for start=stop
1 parent 09e6f44 commit 3d955c1

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
842842
torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
843843
double startSeconds,
844844
std::optional<double> stopSecondsOptional) {
845+
validateActiveStream(AVMEDIA_TYPE_AUDIO);
846+
845847
double stopSeconds =
846848
stopSecondsOptional.value_or(std::numeric_limits<double>::max());
847849

@@ -851,7 +853,10 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
851853
") must be less than or equal to stop seconds (" +
852854
std::to_string(stopSeconds) + ".");
853855

854-
validateActiveStream(AVMEDIA_TYPE_AUDIO);
856+
if (startSeconds == stopSeconds) {
857+
// For consistency with video
858+
return torch::empty({0});
859+
}
855860

856861
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
857862

test/decoders/test_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,15 @@ def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
723723
)
724724
assert frames.shape == expected_shape
725725

726+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
727+
def test_decode_start_equal_stop(self, asset):
728+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
729+
add_audio_stream(decoder)
730+
frames = get_frames_by_pts_in_range_audio(
731+
decoder, start_seconds=1, stop_seconds=1
732+
)
733+
assert frames.shape == (0,)
734+
726735
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
727736
def test_multiple_calls(self, asset):
728737
# Ensure that multiple calls are OK as long as we're decoding

0 commit comments

Comments
 (0)