Skip to content

Commit ce04b3c

Browse files
committed
Fallback to DTS when PTS info is missing
1 parent 4ba7db2 commit ce04b3c

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,16 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
2626
std::round(seconds * timeBase.den / timeBase.num));
2727
}
2828

29-
int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
30-
return avFrame->pts == INT64_MIN? avFrame->pkt_dts : avFrame->pts;
29+
// Some videos aren't properly encoded and do not specify pts values (for
30+
// packets, and thus for frames). Unset values correspond to INT64_MIN. When
31+
// that happens, we fall back to the dts value which hopefully exists and is
32+
// correct.
33+
int64_t getPtsOrDts(ReferenceAVPacket& packet) {
34+
return packet->pts == INT64_MIN ? packet->dts : packet->pts;
3135
}
32-
int64_t getPtsOrDts(ReferenceAVPacket& packet){
33-
return packet->pts == INT64_MIN? packet->dts : packet->pts;
36+
37+
int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
38+
return avFrame->pts == INT64_MIN ? avFrame->pkt_dts : avFrame->pts;
3439
}
3540

3641
} // namespace
@@ -499,8 +504,9 @@ FrameOutput SingleStreamDecoder::getNextFrame() {
499504
FrameOutput SingleStreamDecoder::getNextFrameInternal(
500505
std::optional<torch::Tensor> preAllocatedOutputTensor) {
501506
validateActiveStream();
502-
UniqueAVFrame avFrame = decodeAVFrame(
503-
[this](const UniqueAVFrame& avFrame) { return getPtsOrDts(avFrame) >= cursor_; });
507+
UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
508+
return getPtsOrDts(avFrame) >= cursor_;
509+
});
504510
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
505511
}
506512

@@ -620,7 +626,6 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
620626

621627
FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
622628
validateActiveStream(AVMEDIA_TYPE_VIDEO);
623-
printf("seconds=%f\n", seconds);
624629
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
625630
double frameStartTime =
626631
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -632,24 +637,15 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
632637
// don't cache it locally, we have to rewind back.
633638
seconds = frameStartTime;
634639
}
635-
printf("seconds=%f\n", seconds);
636-
printf("frameStartTime=%f\n", frameStartTime);
637-
printf("frameEndTime=%f\n", frameEndTime);
638-
printf("TimeBase: %d %d\n", streamInfo.timeBase.num, streamInfo.timeBase.den);
639-
printf("In decoding loop\n");
640640

641641
setCursorPtsInSeconds(seconds);
642642
UniqueAVFrame avFrame =
643643
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
644644
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
645-
double frameStartTime = ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
645+
double frameStartTime =
646+
ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
646647
double frameEndTime = ptsToSeconds(
647648
getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
648-
printf("frame pts=%ld\n", avFrame->pts);
649-
printf("frame pkt_dts=%ld\n", avFrame->pkt_dts);
650-
printf("frameStartTime=%f\n", frameStartTime);
651-
printf("frameEndTime=%f\n", frameEndTime);
652-
printf("\n");
653649
if (frameStartTime > seconds) {
654650
// FFMPEG seeked past the frame we are looking for even though we
655651
// set max_ts to be our needed timestamp in avformat_seek_file()
@@ -1164,7 +1160,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
11641160
FrameOutput frameOutput;
11651161
auto& streamInfo = streamInfos_[activeStreamIndex_];
11661162
frameOutput.ptsSeconds = ptsToSeconds(
1167-
getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base);
1163+
getPtsOrDts(avFrame),
1164+
formatContext_->streams[activeStreamIndex_]->time_base);
11681165
frameOutput.durationSeconds = ptsToSeconds(
11691166
getDuration(avFrame),
11701167
formatContext_->streams[activeStreamIndex_]->time_base);

src/torchcodec/_core/_metadata.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def average_fps(self) -> Optional[float]:
142142
self.end_stream_seconds_from_content is None
143143
or self.begin_stream_seconds_from_content is None
144144
or self.num_frames is None
145+
# Should never happen, but prevents ZeroDivisionError:
146+
or self.end_stream_seconds_from_content
147+
== self.begin_stream_seconds_from_content
145148
):
146149
return self.average_fps_from_header
147150
return self.num_frames / (

0 commit comments

Comments
 (0)