Skip to content

Commit 8a4a444

Browse files
committed
Add correct support for getFramesPlayedInRange
1 parent 2ef91b7 commit 8a4a444

File tree

5 files changed

+130
-112
lines changed

5 files changed

+130
-112
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 26 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,22 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
787787
double stopSeconds) {
788788
validateActiveStream();
789789

790+
// Because we currently never seek with audio streams, we prevent users from
791+
// calling this method twice. We could allow multiple calls in the future.
792+
// Assuming 2 consecutive calls:
793+
// ```
794+
// getFramesPlayedInRange(startSeconds1, stopSeconds1);
795+
// getFramesPlayedInRange(startSeconds2, stopSeconds2);
796+
// ```
797+
// We would need to seek back to 0 iff startSeconds2 <= stopSeconds1. This
798+
// logic is not implemented for now, so we just error.
799+
800+
TORCH_CHECK(
801+
streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO ||
802+
!alreadyCalledGetFramesPlayedInRange_,
803+
"Can only decode once with audio stream. Re-create a decoder object if needed.")
804+
alreadyCalledGetFramesPlayedInRange_ = true;
805+
790806
TORCH_CHECK(
791807
startSeconds <= stopSeconds,
792808
"Start seconds (" + std::to_string(startSeconds) +
@@ -869,30 +885,6 @@ void VideoDecoder::setCursorPtsInSeconds(double seconds) {
869885
desiredPtsSeconds_ = seconds;
870886
}
871887

872-
bool VideoDecoder::canWeAvoidSeekingAudio(double desiredPtsSeconds) const {
873-
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
874-
int64_t targetPts = *desiredPtsSeconds_ * streamInfo.timeBase.den;
875-
int64_t lastDecodedAvFramePts = streamInfo.lastDecodedAvFramePts;
876-
877-
if (targetPts <= lastDecodedAvFramePts) {
878-
return false;
879-
}
880-
881-
// We can skip seeking if we want to decode frame `i` and we just decoded
882-
// frame `i - 1`. Note this involves a `log(numFrames)` complexity for each
883-
// decoded frame.
884-
// TODO we should bypass this log(numFrames) logic when calling range APIs
885-
// where the step is 1, because we are sure in this case that all frames
886-
// (except the first one) are consecutive. See a POC at
887-
// https://github.com/pytorch/torchcodec/pull/514
888-
double lastDecodedAvFramePtsSeconds =
889-
ptsToSeconds(lastDecodedAvFramePts, streamInfo.timeBase);
890-
int64_t lastDecodedAvFrameIndex =
891-
secondsToIndexLowerBound(lastDecodedAvFramePtsSeconds);
892-
int64_t targetFrameIndex = secondsToIndexLowerBound(desiredPtsSeconds);
893-
return (lastDecodedAvFrameIndex + 1 == targetFrameIndex);
894-
}
895-
896888
/*
897889
Videos have I frames and non-I frames (P and B frames). Non-I frames need data
898890
from the previous I frame to be decoded.
@@ -918,9 +910,13 @@ I P P P I P P P I P P I P P I P
918910
919911
(2) is more efficient than (1) if there is an I frame between x and y.
920912
*/
921-
bool VideoDecoder::canWeAvoidSeekingVideo(int64_t targetPts) const {
922-
int64_t lastDecodedAvFramePts =
923-
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
913+
bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
914+
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
915+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
916+
return true;
917+
}
918+
919+
int64_t lastDecodedAvFramePts = streamInfo.lastDecodedAvFramePts;
924920
if (targetPts < lastDecodedAvFramePts) {
925921
// We can never skip a seek if we are seeking backwards.
926922
return false;
@@ -954,16 +950,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
954950

955951
decodeStats_.numSeeksAttempted++;
956952

957-
// TODO_CODE_QUALITY The different signature is unfortunate
958-
bool canAvoidSeeking = false;
959-
auto avMediaType = streamInfos_.at(activeStreamIndex_).avMediaType;
960-
if (avMediaType == AVMEDIA_TYPE_AUDIO) {
961-
canAvoidSeeking = canWeAvoidSeekingAudio(*desiredPtsSeconds_);
962-
} else {
963-
canAvoidSeeking = canWeAvoidSeekingVideo(desiredPts);
964-
}
965-
966-
if (canAvoidSeeking) {
953+
if (canWeAvoidSeeking(desiredPts)) {
967954
decodeStats_.numSeeksSkipped++;
968955
return;
969956
}
@@ -973,85 +960,13 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
973960
// the key frame that we want to seek to.
974961
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
975962
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
976-
if (avMediaType == AVMEDIA_TYPE_VIDEO && !streamInfo.keyFrames.empty()) {
963+
if (!streamInfo.keyFrames.empty()) {
977964
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
978965
streamInfo.keyFrames, desiredPts);
979966
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
980967
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
981968
}
982969

983-
if (avMediaType == AVMEDIA_TYPE_AUDIO) {
984-
desiredPts -= 1;
985-
// Note [Seek offset for audio]
986-
//
987-
// There is a strange FFmpeg behavior when decoding audio frames: seeking at
988-
// a frame start and then flushing buffers with avcodec_flush_buffers (as is
989-
// recommended by the FFmpeg docs) leads to the samples to be decoded
990-
// incorrectly. It's difficult to really determine what's going on, but the
991-
// fact is that there exists a data dependency between frames: for frame `i`
992-
// to be correct, then the packet of frame `i-1` needs to be sent to the
993-
// decoder, and there must be no flushing in-between. The naive (and
994-
// incorrect) fix of just *not* flushing only works when we're decoding
995-
// consecutive frames, but fails when decoding non-consecutive frames. We
996-
// try to mitigate this issue via two different means:
997-
// - A. We try to avoid seeking (and thus flushing) as much as possible.
998-
// Typically, we don't need to seek if we want frame `i` and we just
999-
// decoded frame `i - 1`: we just need to return the next frame. This
1000-
// happens in the logic of `canWeAvoidSeekingAudio()`.
1001-
// - B. Instead of seeking to desiredPts, we seek to desiredPts - 1.
1002-
// Effectively, this leads us to decode frame `i-1` before decoding frame
1003-
// `i`. Our `filterFunction` logic in `decodeAVFrame()` ensures that we
1004-
// are returning frame `i` (and not `i - 1`), and because we just decoded
1005-
// frame `i-1`, frame `i` is correct.
1006-
//
1007-
// Strategy B works most of the time: in most decoding APIs, we first
1008-
// convert a frame's pts to an index, and then use that corresponding
1009-
// index's pts to decide where to seek. This means that `desiredPts` usually
1010-
// lands *exactly* where frame `i` starts, and `desiredPts - 1` is the last
1011-
// pts of frame `i-1`, so we do end up seeking (as desired) to frame `i-1`.
1012-
// But, there are cases where this offset trick won't work: if `desiredPts`
1013-
// isn't exactly at a frame's beginning. This corresponds to the following
1014-
// scenarios:
1015-
// - When calling any API in approximate mode *and* if the framerate isn't
1016-
// constant. Because the framerate isn't constant, it's likely that the
1017-
// index won't be correct, and that the index -> pts conversion won't land
1018-
// exactly at a frame start either.
1019-
// - When calling `getFramePlayedAt(pts)`, regardless of the mode, if `pts`
1020-
// doesn't land exactly at a frame's start. We have tests that currently
1021-
// exhibit this behavior: test_get_frame_at_pts_audio_bad(). The "obvious"
1022-
// fix for this is to let `getFramePlayedAt` convert the pts to an index,
1023-
// just like the rest of the APIs.
1024-
//
1025-
// TODO HOW DO WE ADDRESS THIS??
1026-
//
1027-
// A few more notes:
1028-
// - This offset trick does work for the first frame at pts=0: we'll seek to
1029-
// -1, and this leads to a first packet with pts=-1024 to be sent to the
1030-
// decoder (on our test data), leading to frame 0 to be correctly decoded.
1031-
// - The data dependency / buffer flushing issue can be observed on
1032-
// compressed formats like aac or mp3. It doesn't happen on uncompressed
1033-
// formats like wav, where the decoder's buffers are likely unused. We
1034-
// could skip this entire logic for such formats.
1035-
// - All this *seems* to be related to this 13yo+ issue:
1036-
// https://stackoverflow.com/questions/7989623/ffmpeg-seeking-brings-audio-artifacts
1037-
// But according to the thread, the problem there (which has been fixed)
1038-
// seemed to be **lack** of flushing.
1039-
// - So far we have only observed a data-dependency of 1 frame: we need to
1040-
// decode frame `i-1` to decode `i`. It's possible that there exist
1041-
// longer data dependencies of more than 1 frame on other videos /
1042-
// formats. We just haven't observed those yet. If this happens to be the
1043-
// case, then we have a much harder problem to solve.
1044-
// - This weird FFmpeg behavior is observable not just in Torchcodec, it
1045-
// really seems to be an FFmpeg thing. Other decoders have the same
1046-
// problem, like the ones in TorchVision. Those who do not exhibit this
1047-
// behavior are solving it in inefficient ways: Decord effectively decodes
1048-
// and caches the *entire* file when it is created, thus resolving the
1049-
// data dependency. Similarly, TorchAudio effectively always decodes all
1050-
// frames up to frame `i`, even after seeking to frame `i`, because it
1051-
// sets the 'backwards' flag when it calls `av_seek_frame`: it actually
1052-
// always seeks back to the beginning.
1053-
}
1054-
1055970
int ffmepgStatus = avformat_seek_file(
1056971
formatContext_.get(),
1057972
streamInfo.streamIndex,
@@ -1130,6 +1045,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
11301045
if (ffmpegStatus == AVERROR_EOF) {
11311046
// End of file reached. We must drain the codec by sending a nullptr
11321047
// packet.
1048+
11331049
ffmpegStatus = avcodec_send_packet(
11341050
streamInfo.codecContext.get(),
11351051
/*avpkt=*/nullptr);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,7 @@ class VideoDecoder {
365365
// DECODING APIS AND RELATED UTILS
366366
// --------------------------------------------------------------------------
367367

368-
bool canWeAvoidSeekingVideo(int64_t targetPts) const;
369-
bool canWeAvoidSeekingAudio(double desiredPtsSeconds) const;
368+
bool canWeAvoidSeeking(int64_t targetPts) const;
370369

371370
void maybeSeekToBeforeDesiredPts();
372371

@@ -487,6 +486,7 @@ class VideoDecoder {
487486
bool scannedAllStreams_ = false;
488487
// Tracks that we've already been initialized.
489488
bool initialized_ = false;
489+
bool alreadyCalledGetFramesPlayedInRange_ = false;
490490
};
491491

492492
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ void add_audio_stream(
232232
}
233233

234234
void seek_to_pts(at::Tensor& decoder, double seconds) {
235+
// TODO we should prevent more than one call to this op for audio streams, for
236+
// the same reasons we do so for getFramesPlayedInRange(). But we can't
237+
// implement the logic here, because we don't know media type (audio vs
238+
// video). We also can't do it within setCursorPtsInSeconds because it's used
239+
// by all other decoding methods.
240+
// This isn't un-doable, just not easy with the API we currently have.
241+
235242
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
236243
videoDecoder->setCursorPtsInSeconds(seconds);
237244
}

test/decoders/test_video_decoder_ops.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,94 @@ def test_audio_bad_method(self, method):
638638
):
639639
method(decoder)
640640

641+
@pytest.mark.parametrize(
642+
"start_seconds, stop_seconds",
643+
(
644+
# Beginning to end
645+
(0, 13.05),
646+
# At frames boundaries. Frame duration is exactly 0.064 seconds for
647+
# NASA_AUDIO. Need artificial -1e-5 for upper-bound to align the
648+
# reference_frames with the frames returned by the decoder, where
649+
# the interval is half-open.
650+
(0.064 * 4, 0.064 * 20 - 1e-5),
651+
# Not at frames boundaries
652+
(2, 4),
653+
),
654+
)
655+
def test_audio_get_frames_by_pts_in_range(self, start_seconds, stop_seconds):
656+
decoder = create_from_file(str(NASA_AUDIO.path))
657+
add_audio_stream(decoder)
658+
659+
reference_frames = NASA_AUDIO.get_frame_data_by_range(
660+
start=NASA_AUDIO.pts_to_frame_index(start_seconds),
661+
stop=NASA_AUDIO.pts_to_frame_index(stop_seconds) + 1,
662+
)
663+
frames, _, _ = get_frames_by_pts_in_range(
664+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
665+
)
666+
667+
assert_frames_equal(frames, reference_frames)
668+
669+
def test_audio_get_frames_by_pts_in_range_multiple_calls(self):
670+
decoder = create_from_file(str(NASA_AUDIO.path))
671+
add_audio_stream(decoder)
672+
673+
get_frames_by_pts_in_range(decoder, start_seconds=0, stop_seconds=1)
674+
with pytest.raises(
675+
RuntimeError, match="Can only decode once with audio stream"
676+
):
677+
get_frames_by_pts_in_range(decoder, start_seconds=0, stop_seconds=1)
678+
679+
def test_audio_seek_and_next(self):
680+
decoder = create_from_file(str(NASA_AUDIO.path))
681+
add_audio_stream(decoder)
682+
683+
pts = 2
684+
# Need +1 because we're not at frames boundaries
685+
reference_frame = NASA_AUDIO.get_frame_data_by_index(
686+
NASA_AUDIO.pts_to_frame_index(pts) + 1
687+
)
688+
seek_to_pts(decoder, pts)
689+
frame, _, _ = get_next_frame(decoder)
690+
assert_frames_equal(frame, reference_frame)
691+
692+
# Seeking forward is OK
693+
pts = 4
694+
reference_frame = NASA_AUDIO.get_frame_data_by_index(
695+
NASA_AUDIO.pts_to_frame_index(pts) + 1
696+
)
697+
seek_to_pts(decoder, pts)
698+
frame, _, _ = get_next_frame(decoder)
699+
assert_frames_equal(frame, reference_frame)
700+
701+
# Seeking backwards doesn't error, but it's wrong. See TODO in
702+
# `seek_to_pts` op.
703+
prev_pts = pts
704+
pts = 1
705+
seek_to_pts(decoder, pts)
706+
frame, _, _ = get_next_frame(decoder)
707+
# the decoder actually didn't seek, so the frame we're getting is just
708+
# the "next" one without seeking. This assertion exists to illustrate
709+
# what currently happens, but it's obviously *wrong*.
710+
reference_frame = NASA_AUDIO.get_frame_data_by_index(
711+
NASA_AUDIO.pts_to_frame_index(prev_pts) + 2
712+
)
713+
assert_frames_equal(frame, reference_frame)
714+
715+
# def test_audio_seek_and_next_backwards(self):
716+
# decoder = create_from_file(str(NASA_AUDIO.path))
717+
# add_audio_stream(decoder)
718+
719+
# for pts in (4.5, 2):
720+
# # Need +1 because we're not at frames boundaries
721+
# reference_frame = NASA_AUDIO.get_frame_data_by_index(NASA_AUDIO.pts_to_frame_index(pts) + 1)
722+
# seek_to_pts(decoder, pts)
723+
# frame, _, _ = get_next_frame(decoder)
724+
# # assert_frames_equal(frame, reference_frame)
725+
726+
# reference_frame = NASA_AUDIO.get_frame_data_by_index(NASA_AUDIO.pts_to_frame_index(4.5) + 2)
727+
# assert_frames_equal(frame, reference_frame)
728+
641729

642730
if __name__ == "__main__":
643731
pytest.main()

test/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,13 @@ def get_frame_data_by_index(
356356

357357
return self._reference_frames[idx]
358358

359+
def pts_to_frame_index(self, pts_seconds: float) -> int:
360+
# These are hard-coded values assuming stream 4 of nasa_13013.mp4. Each
361+
# of the 204 frames contains 1024 samples.
362+
# TODO make this more generic
363+
frame_duration_seconds = 0.064
364+
return int(pts_seconds // frame_duration_seconds)
365+
359366
# TODO: this shouldn't be named chw. Also values are hard-coded
360367
@property
361368
def empty_chw_tensor(self) -> torch.Tensor:

0 commit comments

Comments
 (0)