Skip to content

Commit 6c7e31f

Browse files
committed
tons of comments
1 parent a2b10ca commit 6c7e31f

File tree

2 files changed

+132
-9
lines changed

2 files changed

+132
-9
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,13 @@ bool VideoDecoder::canWeAvoidSeekingAudio(double desiredPtsSeconds) const {
877877
return false;
878878
}
879879

880+
// We can skip seeking if we want to decoder frame `i` and we just decoded
881+
// frame `i - 1`. Note this involves a `log(numFrames)` complexity for each
882+
// decoded frame.
883+
// TODO we should bypass this log(numFrames) logic when calling range APIs
884+
// where the step is 1, because we are sure in this case that all frames
885+
// (except the first one) are consecutive. See a POC at
886+
// https://github.com/pytorch/torchcodec/pull/514
880887
double lastDecodedAvFramePtsSeconds =
881888
ptsToSeconds(lastDecodedAvFramePts, streamInfo.timeBase);
882889
int64_t lastDecodedAvFrameIndex =
@@ -972,17 +979,81 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
972979
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
973980
}
974981

975-
// TODO explain this nasty hack
976-
// This probably only works if the desired pts corresponds exactly to a frame
977-
// start.
978-
int64_t offset = avMediaType == AVMEDIA_TYPE_VIDEO ? 0 : -1;
982+
if (avMediaType == AVMEDIA_TYPE_AUDIO) {
983+
desiredPts -= 1;
984+
// Note [Seek offset for audio]
985+
//
986+
// There is a strange FFmpeg behavior when decoding audio frames: seeking at
987+
// a frame start and then flushing buffers with avcodec_flush_buffers (as is
988+
// recommended by the FFmpeg docs) leads to the samples to be decoded
989+
// incorrectly. It's difficult to really determine what's going on, but the
990+
// fact is that there exist a data dependency between frames: for frame `i`
991+
// to be correct, then the packet of frame `i-1` needs to be sent to the
992+
// decoder, and there must be no flushing in-between. The naive (and
993+
// incorrect) fix of just *not* flushing only works when we're decoding
994+
// consecutive frames, but fails when decoding non-consecutive frames. We
995+
// try to mitigate this issue via two different means:
996+
// - A. We try to avoid seeking (and thus flushing) as much as possible.
997+
// Typically, we don't need to seek if we want frame `i` and we just
998+
// decoded frame `i - 1`: we just need to return the next frame. This
999+
// happens in the logic of `canWeAvoidSeekingAudio()`.
1000+
// - B. Instead of seeking to desiredPts, we seek to desiredPts - 1.
1001+
// Effectively, this leads us to decode frame `i-1` before decoding frame
1002+
// `i`. Our `filterFunction` logic in `decodeAVFrame()` ensures that we
1003+
// are returning frame `i` (and not `i - 1`), and because we just decoded
1004+
// frame `i-1`, frame `i` is correct.
1005+
//
1006+
// Strategy B works most of the time: in most decoding APIs, we first
1007+
// convert a frame's pts to an index, and then use that corresponding
1008+
// index's pts to decide where to seek. This means that `desiredPts` usually
1009+
// lands *exactly* where frame `i` starts, and `desiredPts - 1` is the last
1010+
// pts of frame `i-1`, so we do end up seeking (as desired) to frame `i-1`.
1011+
// But, there are cases where this offset trick won't work: if `desiredPts`
1012+
// isn't exactly at a frame's beginning. This corresponds to the following
1013+
// scenarios:
1014+
// - When calling any API in approximate mode *and* if the framerate isn't
1015+
// constant. Because the framerate isn't constant, it's likely that the
1016+
// index won't be correct, and that the index -> pts conversion won't land
1017+
// exactly at a frame start either.
1018+
// - When calling `getFramePlayedAt(pts)`, regardless of the mode, if `pts`
1019+
// doesn't land exactly at a frame's start. We have tests that currently
1020+
// exhibit this behavior: test_get_frame_at_pts_audio_bad().
1021+
// TODO HOW DO WE FIX THIS??
1022+
1023+
// A few notes:
1024+
// - This offset trick does work for the first frame at pts=0: we'll seek to
1025+
// -1, and this leads to a first packet with pts=-1024 to be sent to the
1026+
// decoder (on our test data), leading to frame 0 to be correctly decoded.
1027+
// - The data dependency / buffer flushing issue can be observed on
1028+
// compressed formats like aac or mp3. It doesn't happen on uncompressed
1029+
// formats like wav, where the decoder's buffers are likely unused. We
1030+
// could skip this entire logic for such formats.
1031+
// - All this *seems* to be related to this 13yo+ issue:
1032+
// https://stackoverflow.com/questions/7989623/ffmpeg-seeking-brings-audio-artifacts
1033+
// But according to the thread, the problem there (which has been fixed)
1034+
// seemed to be **lack** of flushing.
1035+
// - So far we have only observed a data-dependency of 1 frame: we need to
1036+
// decode frame `i-1` to decode `i`. It's possible that there exist
1037+
// longer data dependencies of more than 1 frame on other videos /
1038+
// formats. We just haven't observed those yet. If this happens to be the
1039+
// case, then we have a much harder problem to solve.
1040+
// - This weird FFmpeg behavior is observable not just in Torchcodec, it
1041+
// really seems to be an FFmpeg thing. Other decoders have the same
1042+
// problem, like the ones in TorchVision. Those who do not exhibit this
1043+
// behavior are solving it in inefficient ways: Decord effectively decodes
1044+
// and caches the *entire* file when it is created, thus resolving the
1045+
// data dependency. Similarly, TorchAudio effectively always decodes all
1046+
// frames up to frame `i`, even after seeking to frame `i`, because it
1047+
// sets the 'backwards' flag when it calls `av_seek_frame`: it actually
1048+
// always seeks back to the beginning.
1049+
}
9791050

9801051
int ffmepgStatus = avformat_seek_file(
9811052
formatContext_.get(),
9821053
streamInfo.streamIndex,
9831054
INT64_MIN,
984-
desiredPts + offset,
985-
desiredPts + offset,
1055+
desiredPts,
1056+
desiredPts,
9861057
0);
9871058

9881059
if (ffmepgStatus < 0) {

test/decoders/test_video_decoder_ops.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,29 @@ def test_get_frame_at_pts_audio(self, seek_mode):
195195
with pytest.raises(AssertionError):
196196
assert_frames_equal(next_frame, reference_frame6)
197197

198+
def test_get_frame_at_pts_audio_bad(self):
199+
decoder = create_from_file(str(NASA_AUDIO.path))
200+
add_audio_stream(decoder=decoder)
201+
202+
reference_frame6 = NASA_AUDIO.get_frame_data_by_index(
203+
INDEX_OF_AUDIO_FRAME_AFTER_SEEKING_AT_6
204+
)
205+
frame6, _, _ = get_frame_at_pts(decoder, 6.05)
206+
# See Note [Seek offset for audio].
207+
# The frame played at 6.05 should be the reference frame, but because
208+
# 6.05 isn't exactly the beginning of that frame, the samples are
209+
# decoded incorrectly.
210+
# TODO Fix this.
211+
with pytest.raises(AssertionError):
212+
assert_frames_equal(frame6, reference_frame6)
213+
214+
# And yet another quirk: if we try to decode it again, we actually end
215+
# up with the samples being correctly decoded. This is because we have a
216+
# custom logic within getFramePlayedAt() that resets desiredPts to the
217+
# pts of the beginning of the frame in some very specific cases.
218+
frame6, _, _ = get_frame_at_pts(decoder, 6.05)
219+
assert_frames_equal(frame6, reference_frame6)
220+
198221
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
199222
@pytest.mark.parametrize("device", cpu_and_cuda())
200223
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
@@ -779,34 +802,48 @@ def test_cuda_decoder(self):
779802
)
780803

781804
def test_get_same_frame_twice(self):
805+
# Non-regression tests that were useful while developing audio support.
782806
def make_decoder():
783807
decoder = create_from_file(str(NASA_AUDIO.path))
784808
add_audio_stream(decoder)
785809
return decoder
786810

787811
for frame_index in (0, 10, 15):
812+
ref = NASA_AUDIO.get_frame_data_by_index(frame_index)
813+
788814
decoder = make_decoder()
789815
a = get_frame_at_index(decoder, frame_index=frame_index)
790816
b = get_frame_at_index(decoder, frame_index=frame_index)
791817
torch.testing.assert_close(a, b)
818+
torch.testing.assert_close(a[0], ref)
792819

793820
decoder = make_decoder()
794821
a = get_frames_at_indices(decoder, frame_indices=[frame_index])
795822
b = get_frames_at_indices(decoder, frame_indices=[frame_index])
796823
torch.testing.assert_close(a, b)
824+
torch.testing.assert_close(a[0][0], ref)
797825

798826
decoder = make_decoder()
799827
a = get_frames_in_range(decoder, start=frame_index, stop=frame_index + 1)
800828
b = get_frames_in_range(decoder, start=frame_index, stop=frame_index + 1)
801829
torch.testing.assert_close(a, b)
830+
torch.testing.assert_close(a[0][0], ref)
802831

803-
pts_at_frame_start = 0
832+
pts_at_frame_start = 0 # 0 corresponds exactly to a frame start
833+
index_of_frame_at_0 = 0
804834
pts_not_at_frame_start = 2 # second 2 is in the middle of a frame
805-
for pts in (pts_at_frame_start, pts_not_at_frame_start):
835+
index_of_frame_at_2 = 31
836+
for pts, frame_index in (
837+
(pts_at_frame_start, index_of_frame_at_0),
838+
(pts_not_at_frame_start, index_of_frame_at_2),
839+
):
840+
ref = NASA_AUDIO.get_frame_data_by_index(frame_index)
841+
806842
decoder = make_decoder()
807843
a = get_frames_by_pts(decoder, timestamps=[pts])
808844
b = get_frames_by_pts(decoder, timestamps=[pts])
809845
torch.testing.assert_close(a, b)
846+
torch.testing.assert_close(a[0][0], ref)
810847

811848
decoder = make_decoder()
812849
a = get_frames_by_pts_in_range(
@@ -816,11 +853,15 @@ def make_decoder():
816853
decoder, start_seconds=pts, stop_seconds=pts + 1e-4
817854
)
818855
torch.testing.assert_close(a, b)
856+
torch.testing.assert_close(a[0][0], ref)
819857

820858
decoder = make_decoder()
821859
a = get_frame_at_pts(decoder, seconds=pts_at_frame_start)
822860
b = get_frame_at_pts(decoder, seconds=pts_at_frame_start)
823861
torch.testing.assert_close(a, b)
862+
torch.testing.assert_close(
863+
a[0], NASA_AUDIO.get_frame_data_by_index(index_of_frame_at_0)
864+
)
824865

825866
decoder = make_decoder()
826867
a_frame, a_pts, a_duration = get_frame_at_pts(
@@ -831,8 +872,17 @@ def make_decoder():
831872
)
832873
torch.testing.assert_close(a_pts, b_pts)
833874
torch.testing.assert_close(a_duration, b_duration)
875+
# TODO fix this. These checks should pass
834876
with pytest.raises(AssertionError):
835877
torch.testing.assert_close(a_frame, b_frame)
878+
with pytest.raises(AssertionError):
879+
torch.testing.assert_close(
880+
a_frame, NASA_AUDIO.get_frame_data_by_index(index_of_frame_at_2)
881+
)
882+
# But second time works ¯\_(ツ)_/¯A (see also test_get_frame_at_pts_audio_bad())
883+
torch.testing.assert_close(
884+
b_frame, NASA_AUDIO.get_frame_data_by_index(index_of_frame_at_2)
885+
)
836886

837887
decoder = make_decoder()
838888
seek_to_pts(decoder, pts_at_frame_start)
@@ -841,13 +891,15 @@ def make_decoder():
841891
b = get_next_frame(decoder)
842892
torch.testing.assert_close(a, b)
843893

844-
# TODO: Wait WTFFF, this should not pass
845894
decoder = make_decoder()
846895
seek_to_pts(decoder, seconds=pts_not_at_frame_start)
847896
a = get_next_frame(decoder)
848897
seek_to_pts(decoder, seconds=pts_not_at_frame_start)
849898
b = get_next_frame(decoder)
850899
torch.testing.assert_close(a, b)
900+
torch.testing.assert_close(
901+
a[0], NASA_AUDIO.get_frame_data_by_index(index_of_frame_at_2 + 1)
902+
)
851903

852904

853905
if __name__ == "__main__":

0 commit comments

Comments
 (0)