Skip to content

Commit 9d6bcff

Browse files
committed
WIP
1 parent 721c315 commit 9d6bcff

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
2929
return static_cast<int64_t>(std::round(seconds * timeBase.den));
3030
}
3131

32+
int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
33+
return avFrame->pts == INT64_MIN? avFrame->pkt_dts : avFrame->pts;
34+
}
35+
// Returns the packet's presentation timestamp, or its decode timestamp when
// no pts was set.
// NOTE(review): the INT64_MIN sentinel presumably stands in for FFmpeg's
// AV_NOPTS_VALUE (numerically identical) — confirm and consider using the
// named constant instead.
int64_t getPtsOrDts(ReferenceAVPacket& packet) {
  if (packet->pts == INT64_MIN) {
    return packet->dts;
  }
  return packet->pts;
}
38+
3239
} // namespace
3340

3441
// --------------------------------------------------------------------------
@@ -225,16 +232,16 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
225232
int streamIndex = packet->stream_index;
226233
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
227234
streamMetadata.minPtsFromScan = std::min(
228-
streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
235+
streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
229236
streamMetadata.maxPtsFromScan = std::max(
230237
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
231-
packet->pts + packet->duration);
238+
getPtsOrDts(packet) + packet->duration);
232239
streamMetadata.numFramesFromScan =
233240
streamMetadata.numFramesFromScan.value_or(0) + 1;
234241

235242
// Note that we set the other value in this struct, nextPts, only after
236243
// we have scanned all packets and sorted by pts.
237-
FrameInfo frameInfo = {packet->pts};
244+
FrameInfo frameInfo = {getPtsOrDts(packet)};
238245
if (packet->flags & AV_PKT_FLAG_KEY) {
239246
frameInfo.isKeyFrame = true;
240247
streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
@@ -496,7 +503,7 @@ FrameOutput SingleStreamDecoder::getNextFrameInternal(
496503
std::optional<torch::Tensor> preAllocatedOutputTensor) {
497504
validateActiveStream();
498505
UniqueAVFrame avFrame = decodeAVFrame(
499-
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
506+
[this](const UniqueAVFrame& avFrame) { return getPtsOrDts(avFrame) >= cursor_; });
500507
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
501508
}
502509

@@ -616,6 +623,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
616623

617624
FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
618625
validateActiveStream(AVMEDIA_TYPE_VIDEO);
626+
printf("seconds=%f\n", seconds);
619627
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
620628
double frameStartTime =
621629
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -627,14 +635,24 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
627635
// don't cache it locally, we have to rewind back.
628636
seconds = frameStartTime;
629637
}
638+
printf("seconds=%f\n", seconds);
639+
printf("frameStartTime=%f\n", frameStartTime);
640+
printf("frameEndTime=%f\n", frameEndTime);
641+
printf("TimeBase: %d %d\n", streamInfo.timeBase.num, streamInfo.timeBase.den);
642+
printf("In decoding loop\n");
630643

631644
setCursorPtsInSeconds(seconds);
632645
UniqueAVFrame avFrame =
633646
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
634647
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
635-
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
648+
double frameStartTime = ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
636649
double frameEndTime = ptsToSeconds(
637-
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
650+
getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
651+
printf("frame pts=%ld\n", avFrame->pts);
652+
printf("frame pkt_dts=%ld\n", avFrame->pkt_dts);
653+
printf("frameStartTime=%f\n", frameStartTime);
654+
printf("frameEndTime=%f\n", frameEndTime);
655+
printf("\n");
638656
if (frameStartTime > seconds) {
639657
// FFMPEG seeked past the frame we are looking for even though we
640658
// set max_ts to be our needed timestamp in avformat_seek_file()
@@ -861,8 +879,8 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
861879
try {
862880
UniqueAVFrame avFrame =
863881
decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
864-
return startPts < avFrame->pts + getDuration(avFrame) &&
865-
stopPts > avFrame->pts;
882+
return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) &&
883+
stopPts > getPtsOrDts(avFrame);
866884
});
867885
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
868886
if (!firstFramePtsSeconds.has_value()) {
@@ -1132,7 +1150,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11321150
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
11331151
// av_receive_frame() or the user will have seeked to a different location in
11341152
// the file and that will flush the decoder.
1135-
streamInfo.lastDecodedAvFramePts = avFrame->pts;
1153+
streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
11361154
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
11371155

11381156
return avFrame;
@@ -1149,7 +1167,7 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
11491167
FrameOutput frameOutput;
11501168
auto& streamInfo = streamInfos_[activeStreamIndex_];
11511169
frameOutput.ptsSeconds = ptsToSeconds(
1152-
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
1170+
getPtsOrDts(avFrame), formatContext_->streams[activeStreamIndex_]->time_base);
11531171
frameOutput.durationSeconds = ptsToSeconds(
11541172
getDuration(avFrame),
11551173
formatContext_->streams[activeStreamIndex_]->time_base);

0 commit comments

Comments
 (0)