Skip to content

Commit ce12f03

Browse files
committed
More validation, more tests
1 parent 82bea4a commit ce12f03

File tree

3 files changed

+54
-28
lines changed

3 files changed

+54
-28
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -842,16 +842,24 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
842842
torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
843843
double startSeconds,
844844
double stopSeconds) {
845+
TORCH_CHECK(
846+
startSeconds <= stopSeconds,
847+
"Start seconds (" + std::to_string(startSeconds) +
848+
") must be less than or equal to stop seconds (" +
849+
std::to_string(stopSeconds) + ".");
850+
845851
validateActiveStream(AVMEDIA_TYPE_AUDIO);
846852

847853
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
848-
double frameStartTime =
849-
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
850-
double frameEndTime = ptsToSeconds(
851-
streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration,
852-
streamInfo.timeBase);
853854

854-
TORCH_CHECK(startSeconds > frameEndTime, "OSKOOOOOUUUUUUURRRRRR");
855+
auto lastDecodedFrameIsPlayedAtStopSeconds =
856+
[this, &streamInfo, stopSeconds]() {
857+
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
858+
return (
859+
streamInfo.lastDecodedAvFramePts <= stopPts and
860+
stopPts <= streamInfo.lastDecodedAvFramePts +
861+
streamInfo.lastDecodedAvFrameDuration);
862+
};
855863

856864
setCursorPtsInSeconds(startSeconds);
857865

@@ -860,26 +868,19 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
860868
// sample rate, so in theory we know the number of output samples.
861869
std::vector<torch::Tensor> tensors;
862870

863-
while (true) {
864-
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
865-
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
866-
return (avFrame->pts >= activeStreamInfo.discardFramesBeforePts) ||
867-
(avFrame->pts < activeStreamInfo.discardFramesBeforePts &&
868-
activeStreamInfo.discardFramesBeforePts <
869-
avFrame->pts + avFrame->duration);
870-
});
871-
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
872-
tensors.push_back(frameOutput.data);
873-
874-
double lastFrameStartPts =
875-
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
876-
double lastFrameEndPts = ptsToSeconds(
877-
streamInfo.lastDecodedAvFramePts +
878-
streamInfo.lastDecodedAvFrameDuration,
879-
streamInfo.timeBase);
880-
881-
if (lastFrameStartPts <= stopSeconds and stopSeconds <= lastFrameEndPts) {
882-
break;
871+
bool reachedEOF = false;
872+
while (!lastDecodedFrameIsPlayedAtStopSeconds() && !reachedEOF) {
873+
try {
874+
AVFrameStream avFrameStream =
875+
decodeAVFrame([&streamInfo](AVFrame* avFrame) {
876+
return (
877+
streamInfo.discardFramesBeforePts <
878+
avFrame->pts + getDuration(avFrame));
879+
});
880+
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
881+
tensors.push_back(frameOutput.data);
882+
} catch (const EndOfFileException& e) {
883+
reachedEOF = true;
883884
}
884885
}
885886
return torch::cat(tensors, 1);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ class VideoDecoder {
339339
// The current position of the cursor in the stream, and associated frame
340340
// duration.
341341
int64_t lastDecodedAvFramePts = 0;
342-
int64_t lastDecodedAvFrameDuration = -1;
342+
int64_t lastDecodedAvFrameDuration = 0;
343343
// The desired position of the cursor in the stream. We send frames >=
344344
// this pts to the user when they request a frame.
345345
// We update this field if the user requested a seek. This typically

test/decoders/test_ops.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,12 +664,20 @@ def test_audio_decode_all_samples_with_next(self, asset):
664664
assert_frames_equal(all_frames, reference_frames)
665665

666666
@pytest.mark.parametrize(
667-
"range", ("begin_to_end", "at_frame_boundaries", "not_at_frame_boundaries")
667+
"range",
668+
(
669+
"begin_to_end",
670+
"begin_to_beyond_end",
671+
"at_frame_boundaries",
672+
"not_at_frame_boundaries",
673+
),
668674
)
669675
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
670676
def test_get_frames_by_pts_in_range_audio(self, range, asset):
671677
if range == "begin_to_end":
672678
start_seconds, stop_seconds = 0, asset.duration_seconds
679+
elif range == "begin_to_beyond_end":
680+
start_seconds, stop_seconds = 0, asset.duration_seconds + 10
673681
elif range == "at_frame_boundaries":
674682
start_seconds = asset.frames[asset.default_stream_index][10].pts_seconds
675683
stop_seconds = asset.frames[asset.default_stream_index][40].pts_seconds
@@ -687,6 +695,9 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
687695
decoder = create_from_file(str(asset.path), seek_mode="approximate")
688696
add_audio_stream(decoder)
689697

698+
# stop_offset logic: if stop_seconds is at a frame boundary i.e. when a
699+
# frame starts, then that frame should *not* be included in the output.
700+
# Otherwise, it should be part of it, which is why we add 1 to `stop=`.
690701
stop_offset = 0 if range == "at_frame_boundaries" else 1
691702
reference_frames = asset.get_frame_data_by_range(
692703
start=asset.get_frame_index(pts_seconds=start_seconds),
@@ -711,6 +722,20 @@ def test_decode_epsilon_range(self, asset, expected_shape):
711722
)
712723
assert frames.shape == expected_shape
713724

725+
@pytest.mark.parametrize(
726+
"asset, expected_shape", ((NASA_AUDIO, (2, 1024)), (NASA_AUDIO_MP3, (2, 576)))
727+
)
728+
def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
729+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
730+
add_audio_stream(decoder)
731+
732+
start_seconds = asset.frames[asset.default_stream_index][10].pts_seconds
733+
stop_seconds = asset.frames[asset.default_stream_index][11].pts_seconds
734+
frames = get_frames_by_pts_in_range_audio(
735+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
736+
)
737+
assert frames.shape == expected_shape
738+
714739
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
715740
def test_seek_and_next_audio(self, asset):
716741
decoder = create_from_file(str(asset.path), seek_mode="approximate")

0 commit comments

Comments
 (0)