Skip to content

Commit f56b259

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into file_like
2 parents fa2445e + b32aabe commit f56b259

File tree

4 files changed

+51
-25
lines changed

4 files changed

+51
-25
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ void VideoDecoder::addStream(
465465
TORCH_CHECK_EQ(retVal, AVSUCCESS);
466466

467467
streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
468+
streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;
468469

469470
// TODO_CODE_QUALITY same as above.
470471
if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) {
@@ -564,13 +565,15 @@ void VideoDecoder::addAudioStream(int streamIndex) {
564565

565566
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
566567
auto output = getNextFrameInternal();
567-
output.data = maybePermuteHWC2CHW(output.data);
568+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
569+
output.data = maybePermuteHWC2CHW(output.data);
570+
}
568571
return output;
569572
}
570573

571574
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
572575
std::optional<torch::Tensor> preAllocatedOutputTensor) {
573-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
576+
validateActiveStream();
574577
AVFrameStream avFrameStream = decodeAVFrame(
575578
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
576579
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
@@ -866,7 +869,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
866869
// If we need to seek backwards, then we have to seek back to the beginning
867870
// of the stream.
868871
// TODO-AUDIO: document why this is needed in a big comment.
869-
setCursorPtsInSeconds(INT64_MIN);
872+
setCursorPtsInSecondsInternal(INT64_MIN);
870873
}
871874

872875
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
@@ -912,6 +915,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
912915
// --------------------------------------------------------------------------
913916

914917
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
918+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
919+
setCursorPtsInSecondsInternal(seconds);
920+
}
921+
922+
void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) {
915923
cursorWasJustSet_ = true;
916924
cursor_ =
917925
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ class VideoDecoder {
372372
// DECODING APIS AND RELATED UTILS
373373
// --------------------------------------------------------------------------
374374

375+
void setCursorPtsInSecondsInternal(double seconds);
375376
bool canWeAvoidSeeking() const;
376377

377378
void maybeSeekToBeforeDesiredPts();

test/decoders/test_decoders.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -990,13 +990,7 @@ def test_get_all_samples(self, asset, stop_seconds):
990990
torch.testing.assert_close(samples.data, reference_frames)
991991
assert samples.sample_rate == asset.sample_rate
992992

993-
# TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
994-
expected_pts = (
995-
0.072
996-
if asset is NASA_AUDIO_MP3
997-
else asset.get_frame_info(idx=0).pts_seconds
998-
)
999-
assert samples.pts_seconds == expected_pts
993+
assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
1000994

1001995
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1002996
def test_at_frame_boundaries(self, asset):
@@ -1060,12 +1054,8 @@ def test_start_equals_stop(self, asset):
10601054
assert samples.data.shape == (0, 0)
10611055

10621056
def test_frame_start_is_not_zero(self):
1063-
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1].
1057+
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125.
10641058
# So if we request start = 0.05, we shouldn't be truncating anything.
1065-
#
1066-
# [1] well, really it's at 0.138125, not 0.072 (see
1067-
# https://github.com/pytorch/torchcodec/issues/553), but for the purpose
1068-
# of this test it doesn't matter.
10691059

10701060
asset = NASA_AUDIO_MP3
10711061
start_seconds = 0.05 # this is less than the first frame's pts

test/decoders/test_ops.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ class TestAudioOps:
631631
partial(get_frames_in_range, start=4, stop=5),
632632
partial(get_frame_at_pts, seconds=2),
633633
partial(get_frames_by_pts, timestamps=[0, 1.5]),
634-
partial(get_next_frame),
634+
partial(seek_to_pts, seconds=5),
635635
),
636636
)
637637
def test_audio_bad_method(self, method):
@@ -647,6 +647,22 @@ def test_audio_bad_seek_mode(self):
647647
):
648648
add_audio_stream(decoder)
649649

650+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
651+
def test_next(self, asset):
652+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
653+
add_audio_stream(decoder)
654+
655+
frame_index = 0
656+
while True:
657+
try:
658+
frame, *_ = get_next_frame(decoder)
659+
except IndexError:
660+
break
661+
torch.testing.assert_close(
662+
frame, asset.get_frame_data_by_index(frame_index)
663+
)
664+
frame_index += 1
665+
650666
@pytest.mark.parametrize(
651667
"range",
652668
(
@@ -831,6 +847,8 @@ def get_reference_frames(start_seconds, stop_seconds):
831847

832848
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
833849
def test_pts(self, asset):
850+
# Non-regression test for
851+
# https://github.com/pytorch/torchcodec/issues/553
834852
decoder = create_from_file(str(asset.path), seek_mode="approximate")
835853
add_audio_stream(decoder)
836854

@@ -845,15 +863,24 @@ def test_pts(self, asset):
845863
frames, asset.get_frame_data_by_index(frame_index)
846864
)
847865

848-
if asset is NASA_AUDIO_MP3 and frame_index == 0:
849-
# TODO This is a bug. The 0.138125 is correct while 0.072 is
850-
# incorrect, even though it comes from the decoded AVFrame's pts
851-
# field.
852-
# See https://github.com/pytorch/torchcodec/issues/553
853-
assert pts_seconds == 0.072
854-
assert start_seconds == 0.138125
855-
else:
856-
assert pts_seconds == start_seconds
866+
assert pts_seconds == start_seconds
867+
868+
def test_decode_before_frame_start(self):
869+
# Test illustrating bug described in
870+
# https://github.com/pytorch/torchcodec/issues/567
871+
asset = NASA_AUDIO_MP3
872+
873+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
874+
add_audio_stream(decoder)
875+
876+
frames, *_ = get_frames_by_pts_in_range_audio(
877+
decoder, start_seconds=0, stop_seconds=0.05
878+
)
879+
all_frames, *_ = get_frames_by_pts_in_range_audio(
880+
decoder, start_seconds=0, stop_seconds=None
881+
)
882+
# TODO fix this. `frames` should be empty.
883+
torch.testing.assert_close(frames, all_frames)
857884

858885

859886
if __name__ == "__main__":

0 commit comments

Comments
 (0)