Skip to content

Commit 3da223c

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into fltp
2 parents db9ff94 + b32aabe commit 3da223c

File tree

4 files changed

+51
-25
lines changed

4 files changed

+51
-25
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ void VideoDecoder::addStream(
468468
TORCH_CHECK_EQ(retVal, AVSUCCESS);
469469

470470
streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
471+
streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;
471472

472473
// TODO_CODE_QUALITY same as above.
473474
if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) {
@@ -573,13 +574,15 @@ void VideoDecoder::addAudioStream(int streamIndex) {
573574

574575
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
575576
auto output = getNextFrameInternal();
576-
output.data = maybePermuteHWC2CHW(output.data);
577+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
578+
output.data = maybePermuteHWC2CHW(output.data);
579+
}
577580
return output;
578581
}
579582

580583
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
581584
std::optional<torch::Tensor> preAllocatedOutputTensor) {
582-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
585+
validateActiveStream();
583586
AVFrameStream avFrameStream = decodeAVFrame(
584587
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
585588
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
@@ -875,7 +878,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
875878
// If we need to seek backwards, then we have to seek back to the beginning
876879
// of the stream.
877880
// TODO-AUDIO: document why this is needed in a big comment.
878-
setCursorPtsInSeconds(INT64_MIN);
881+
setCursorPtsInSecondsInternal(INT64_MIN);
879882
}
880883

881884
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
@@ -921,6 +924,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
921924
// --------------------------------------------------------------------------
922925

923926
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
927+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
928+
setCursorPtsInSecondsInternal(seconds);
929+
}
930+
931+
void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) {
924932
cursorWasJustSet_ = true;
925933
cursor_ =
926934
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ class VideoDecoder {
371371
// DECODING APIS AND RELATED UTILS
372372
// --------------------------------------------------------------------------
373373

374+
void setCursorPtsInSecondsInternal(double seconds);
374375
bool canWeAvoidSeeking() const;
375376

376377
void maybeSeekToBeforeDesiredPts();

test/decoders/test_decoders.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -990,13 +990,7 @@ def test_get_all_samples(self, asset, stop_seconds):
990990
torch.testing.assert_close(samples.data, reference_frames)
991991
assert samples.sample_rate == asset.sample_rate
992992

993-
# TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
994-
expected_pts = (
995-
0.072
996-
if asset is NASA_AUDIO_MP3
997-
else asset.get_frame_info(idx=0).pts_seconds
998-
)
999-
assert samples.pts_seconds == expected_pts
993+
assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
1000994

1001995
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1002996
def test_at_frame_boundaries(self, asset):
@@ -1060,12 +1054,8 @@ def test_start_equals_stop(self, asset):
10601054
assert samples.data.shape == (0, 0)
10611055

10621056
def test_frame_start_is_not_zero(self):
1063-
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1].
1057+
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125.
10641058
# So if we request start = 0.05, we shouldn't be truncating anything.
1065-
#
1066-
# [1] well, really it's at 0.138125, not 0.072 (see
1067-
# https://github.com/pytorch/torchcodec/issues/553), but for the purpose
1068-
# of this test it doesn't matter.
10691059

10701060
asset = NASA_AUDIO_MP3
10711061
start_seconds = 0.05 # this is less than the first frame's pts

test/decoders/test_ops.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ class TestAudioOps:
626626
partial(get_frames_in_range, start=4, stop=5),
627627
partial(get_frame_at_pts, seconds=2),
628628
partial(get_frames_by_pts, timestamps=[0, 1.5]),
629-
partial(get_next_frame),
629+
partial(seek_to_pts, seconds=5),
630630
),
631631
)
632632
def test_audio_bad_method(self, method):
@@ -642,6 +642,22 @@ def test_audio_bad_seek_mode(self):
642642
):
643643
add_audio_stream(decoder)
644644

645+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
646+
def test_next(self, asset):
647+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
648+
add_audio_stream(decoder)
649+
650+
frame_index = 0
651+
while True:
652+
try:
653+
frame, *_ = get_next_frame(decoder)
654+
except IndexError:
655+
break
656+
torch.testing.assert_close(
657+
frame, asset.get_frame_data_by_index(frame_index)
658+
)
659+
frame_index += 1
660+
645661
@pytest.mark.parametrize(
646662
"range",
647663
(
@@ -826,6 +842,8 @@ def get_reference_frames(start_seconds, stop_seconds):
826842

827843
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
828844
def test_pts(self, asset):
845+
# Non-regression test for
846+
# https://github.com/pytorch/torchcodec/issues/553
829847
decoder = create_from_file(str(asset.path), seek_mode="approximate")
830848
add_audio_stream(decoder)
831849

@@ -840,15 +858,24 @@ def test_pts(self, asset):
840858
frames, asset.get_frame_data_by_index(frame_index)
841859
)
842860

843-
if asset is NASA_AUDIO_MP3 and frame_index == 0:
844-
# TODO This is a bug. The 0.138125 is correct while 0.072 is
845-
# incorrect, even though it comes from the decoded AVFrame's pts
846-
# field.
847-
# See https://github.com/pytorch/torchcodec/issues/553
848-
assert pts_seconds == 0.072
849-
assert start_seconds == 0.138125
850-
else:
851-
assert pts_seconds == start_seconds
861+
assert pts_seconds == start_seconds
862+
863+
def test_decode_before_frame_start(self):
864+
# Test illustrating bug described in
865+
# https://github.com/pytorch/torchcodec/issues/567
866+
asset = NASA_AUDIO_MP3
867+
868+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
869+
add_audio_stream(decoder)
870+
871+
frames, *_ = get_frames_by_pts_in_range_audio(
872+
decoder, start_seconds=0, stop_seconds=0.05
873+
)
874+
all_frames, *_ = get_frames_by_pts_in_range_audio(
875+
decoder, start_seconds=0, stop_seconds=None
876+
)
877+
# TODO fix this. `frames` should be empty.
878+
torch.testing.assert_close(frames, all_frames)
852879

853880

854881
if __name__ == "__main__":

0 commit comments

Comments
 (0)