Commit 37fe2b0

Merge branch 'main' of github.com:pytorch/torchcodec into srccc
2 parents: 0108506 + 66c78c2

5 files changed: +77 -31 lines

README.md

Lines changed: 0 additions & 6 deletions
@@ -18,12 +18,6 @@ We achieve these capabilities through:
 * Returning data as PyTorch tensors, ready to be fed into PyTorch transforms
   or used directly to train models.
 
-> [!NOTE]
-> ⚠️ TorchCodec is still in development stage and some APIs may be updated
-> in future versions, depending on user feedback.
-> If you have any suggestions or issues, please let us know by
-> [opening an issue](https://github.com/pytorch/torchcodec/issues/new/choose)!
-
 ## Using TorchCodec
 
 Here's a condensed summary of what you can do with TorchCodec. For more detailed

docs/source/index.rst

Lines changed: 0 additions & 7 deletions
@@ -67,13 +67,6 @@ We achieve these capabilities through:
 
    How to sample regular and random clips from a video
 
-.. note::
-
-   TorchCodec is still in development stage and we are actively seeking
-   feedback. If you have any suggestions or issues, please let us know by
-   `opening an issue <https://github.com/pytorch/torchcodec/issues/new/choose>`_
-   on our `GitHub repository <https://github.com/pytorch/torchcodec/>`_.
-
 .. toctree::
    :maxdepth: 1
    :caption: TorchCodec documentation

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 33 additions & 18 deletions
@@ -17,16 +17,27 @@
 namespace facebook::torchcodec {
 namespace {
 
-double ptsToSeconds(int64_t pts, int den) {
-  return static_cast<double>(pts) / den;
-}
-
 double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
-  return ptsToSeconds(pts, timeBase.den);
+  return static_cast<double>(pts) * timeBase.num / timeBase.den;
 }
 
 int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
-  return static_cast<int64_t>(std::round(seconds * timeBase.den));
+  return static_cast<int64_t>(
+      std::round(seconds * timeBase.den / timeBase.num));
+}
+
+// Some videos aren't properly encoded and do not specify pts values for
+// packets, and thus for frames. Unset values correspond to INT64_MIN. When that
+// happens, we fallback to the dts value which hopefully exists and is correct.
+// Accessing AVFrames and AVPackets's pts values should **always** go through
+// the helpers below. Then, the "pts" fields in our structs like FrameInfo.pts
+// should be interpreted as "pts if it exists, dts otherwise".
+int64_t getPtsOrDts(ReferenceAVPacket& packet) {
+  return packet->pts == INT64_MIN ? packet->dts : packet->pts;
+}
+
+int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
+  return avFrame->pts == INT64_MIN ? avFrame->pkt_dts : avFrame->pts;
 }
 
 } // namespace
@@ -151,8 +162,9 @@ void SingleStreamDecoder::initializeDecoder() {
   }
 
   if (formatContext_->duration > 0) {
+    AVRational defaultTimeBase{1, AV_TIME_BASE};
     containerMetadata_.durationSeconds =
-        ptsToSeconds(formatContext_->duration, AV_TIME_BASE);
+        ptsToSeconds(formatContext_->duration, defaultTimeBase);
   }
 
   if (formatContext_->bit_rate > 0) {
@@ -225,16 +237,16 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
     int streamIndex = packet->stream_index;
     auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
     streamMetadata.minPtsFromScan = std::min(
-        streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
+        streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
     streamMetadata.maxPtsFromScan = std::max(
         streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
-        packet->pts + packet->duration);
+        getPtsOrDts(packet) + packet->duration);
     streamMetadata.numFramesFromScan =
         streamMetadata.numFramesFromScan.value_or(0) + 1;
 
     // Note that we set the other value in this struct, nextPts, only after
     // we have scanned all packets and sorted by pts.
-    FrameInfo frameInfo = {packet->pts};
+    FrameInfo frameInfo = {getPtsOrDts(packet)};
     if (packet->flags & AV_PKT_FLAG_KEY) {
       frameInfo.isKeyFrame = true;
       streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
@@ -495,8 +507,9 @@ FrameOutput SingleStreamDecoder::getNextFrame() {
 FrameOutput SingleStreamDecoder::getNextFrameInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateActiveStream();
-  UniqueAVFrame avFrame = decodeAVFrame(
-      [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
+  UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
+    return getPtsOrDts(avFrame) >= cursor_;
+  });
   return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
 }
 
@@ -632,9 +645,10 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
   UniqueAVFrame avFrame =
       decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
         StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
-        double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
+        double frameStartTime =
+            ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
-            avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
+            getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
         if (frameStartTime > seconds) {
           // FFMPEG seeked past the frame we are looking for even though we
           // set max_ts to be our needed timestamp in avformat_seek_file()
@@ -861,8 +875,8 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
     try {
       UniqueAVFrame avFrame =
           decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
-            return startPts < avFrame->pts + getDuration(avFrame) &&
-                stopPts > avFrame->pts;
+            return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) &&
+                stopPts > getPtsOrDts(avFrame);
           });
       auto frameOutput = convertAVFrameToFrameOutput(avFrame);
       if (!firstFramePtsSeconds.has_value()) {
@@ -1132,7 +1146,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
-  streamInfo.lastDecodedAvFramePts = avFrame->pts;
+  streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
   streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
 
   return avFrame;
@@ -1149,7 +1163,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
   FrameOutput frameOutput;
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   frameOutput.ptsSeconds = ptsToSeconds(
-      avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
+      getPtsOrDts(avFrame),
+      formatContext_->streams[activeStreamIndex_]->time_base);
   frameOutput.durationSeconds = ptsToSeconds(
       getDuration(avFrame),
       formatContext_->streams[activeStreamIndex_]->time_base);

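For reference, the helper changes above amount to rescaling by the full stream time base (num/den rather than den alone) and falling back from pts to dts when pts is unset. A minimal Python sketch of the same arithmetic; the helper names below are illustrative only and not part of the torchcodec API:

INT64_MIN = -(2**63)  # FFmpeg's sentinel for an unset pts/dts

def pts_to_seconds(pts: int, time_base_num: int, time_base_den: int) -> float:
    # One pts tick lasts num/den seconds, so multiply by num and divide by den.
    return pts * time_base_num / time_base_den

def seconds_to_closest_pts(seconds: float, time_base_num: int, time_base_den: int) -> int:
    return round(seconds * time_base_den / time_base_num)

def pts_or_dts(pts: int, dts: int) -> int:
    # Same fallback as getPtsOrDts(): use dts when the packet/frame has no pts.
    return dts if pts == INT64_MIN else pts

# Example: with a 1/30000 time base, pts=1001 is one NTSC frame (~0.0334 s).
assert abs(pts_to_seconds(1001, 1, 30000) - 1001 / 30000) < 1e-12
assert seconds_to_closest_pts(1001 / 30000, 1, 30000) == 1001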
src/torchcodec/_core/_metadata.py

Lines changed: 3 additions & 0 deletions
@@ -142,6 +142,9 @@ def average_fps(self) -> Optional[float]:
             self.end_stream_seconds_from_content is None
             or self.begin_stream_seconds_from_content is None
             or self.num_frames is None
+            # Should never happen, but prevents ZeroDivisionError:
+            or self.end_stream_seconds_from_content
+            == self.begin_stream_seconds_from_content
         ):
             return self.average_fps_from_header
         return self.num_frames / (

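The three added lines guard the division in average_fps: the property divides num_frames by (end_stream_seconds_from_content - begin_stream_seconds_from_content), so equal bounds would raise ZeroDivisionError, and the property now falls back to the header value instead. A small standalone sketch of that logic, written as a free function with illustrative argument names rather than the actual StreamMetadata property:

from typing import Optional

def average_fps(
    num_frames: Optional[int],
    begin_seconds: Optional[float],
    end_seconds: Optional[float],
    fps_from_header: Optional[float],
) -> Optional[float]:
    if (
        end_seconds is None
        or begin_seconds is None
        or num_frames is None
        # Equal bounds would make the division below raise ZeroDivisionError:
        or end_seconds == begin_seconds
    ):
        return fps_from_header
    return num_frames / (end_seconds - begin_seconds)

# Normal case: fps computed from content; degenerate case: header fallback.
assert average_fps(270, 0.0, 9.0, 29.97) == 30.0
assert average_fps(1, 0.0, 0.0, 29.97) == 29.97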
test/test_decoders.py

Lines changed: 41 additions & 0 deletions
@@ -986,6 +986,47 @@ def get_some_frames(decoder):
         assert_frames_equal(ref_frame3, frames[1].data)
         assert_frames_equal(ref_frame5, frames[2].data)
 
+    @pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
+    def test_pts_to_dts_fallback(self, seek_mode):
+        # Non-regression test for
+        # https://github.com/pytorch/torchcodec/issues/677 and
+        # https://github.com/pytorch/torchcodec/issues/676.
+        # More accurately, this is a non-regression test for videos which do
+        # *not* specify pts values (all pts values are N/A and set to
+        # INT64_MIN), but specify *dts* value - which we fallback to.
+        #
+        # The test video we have is from
+        # https://huggingface.co/datasets/raushan-testing-hf/videos-test/blob/main/sample_video_2.avi
+        # We can't check it into the repo due to potential licensing issues, so
+        # we have to unconditionally skip this test.
+        # TODO: encode a video with no pts values to unskip this test. Couldn't
+        # find a way to do that with FFmpeg's CLI, but this should be doable
+        # once we have our own video encoder.
+        pytest.skip(reason="TODO: Need video with no pts values.")
+
+        path = "/home/nicolashug/Downloads/sample_video_2.avi"
+        decoder = VideoDecoder(path, seek_mode=seek_mode)
+        metadata = decoder.metadata
+
+        assert metadata.average_fps == pytest.approx(29.916667)
+        assert metadata.duration_seconds_from_header == 9.02507
+        assert metadata.duration_seconds == 9.02507
+        assert metadata.begin_stream_seconds_from_content == (
+            None if seek_mode == "approximate" else 0
+        )
+        assert metadata.end_stream_seconds_from_content == (
+            None if seek_mode == "approximate" else 9.02507
+        )
+
+        assert decoder[0].shape == (3, 240, 320)
+        decoder[10].shape == (3, 240, 320)
+        decoder.get_frame_at(2).data.shape == (3, 240, 320)
+        decoder.get_frames_at([2, 10]).data.shape == (2, 3, 240, 320)
+        decoder.get_frame_played_at(9).data.shape == (3, 240, 320)
+        decoder.get_frames_played_at([2, 4]).data.shape == (2, 3, 240, 320)
+        with pytest.raises(AssertionError, match="not equal"):
+            torch.testing.assert_close(decoder[0], decoder[10])
+
 
 class TestAudioDecoder:
     @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))
