Skip to content

Commit a6a68bb

Browse files
committed
Add AV1 support
1 parent d6a2cfa commit a6a68bb

File tree

5 files changed

+46
-20
lines changed

5 files changed

+46
-20
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
144144
return cudaVideoCodec_H264;
145145
case AV_CODEC_ID_HEVC:
146146
return cudaVideoCodec_HEVC;
147+
case AV_CODEC_ID_AV1:
148+
return cudaVideoCodec_AV1;
147149
// TODONVDEC P0: support more codecs
148-
// case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
149150
// case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
150151
// case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
151152
// case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
@@ -268,6 +269,7 @@ void BetaCudaDeviceInterface::initializeInterface(
268269

269270
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
270271
timeBase_ = avStream->time_base;
272+
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
271273

272274
const AVCodecParameters* codecPar = avStream->codecpar;
273275
TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null");
@@ -484,14 +486,19 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
484486
avFrame->format = AV_PIX_FMT_CUDA;
485487
avFrame->pts = dispInfo.timestamp;
486488

487-
// TODONVDEC P0: Zero division error!!!
488-
// TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the
489-
// similar SingleStreamDecoder stuff there too.
490-
unsigned int frameRateNum = videoFormat_.frame_rate.numerator;
491-
unsigned int frameRateDen = videoFormat_.frame_rate.denominator;
492-
int64_t duration = static_cast<int64_t>((frameRateDen * timeBase_.den)) /
493-
(frameRateNum * timeBase_.num);
494-
setDuration(avFrame, duration);
489+
// TODONVDEC P2: We compute the duration based on average frame rate info:
490+
// either from NVCUVID if it's valid, otherwise from FFmpeg as fallback. But
491+
// both of these are based on average frame rate, so if the video has
492+
// variable frame rate, the durations may be off. We should try to see if we
493+
// can set the duration more accurately. Unfortunately it's not given by
494+
// dispInfo. One option would be to set it based on the pts difference between
495+
// consecutive frames, if the next frame is already available.
496+
int frameRateNum = static_cast<int>(videoFormat_.frame_rate.numerator);
497+
int frameRateDen = static_cast<int>(videoFormat_.frame_rate.denominator);
498+
AVRational frameRate = (frameRateNum > 0 && frameRateDen > 0)
499+
? AVRational{frameRateNum, frameRateDen}
500+
: frameRateAvgFromFFmpeg_;
501+
setDuration(avFrame, computeSafeDuration(frameRate, timeBase_));
495502

496503
// We need to assign the frame colorspace. This is crucial for proper color
497504
// conversion. NVCUVID stores that in the matrix_coefficients field, but

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
8686
// isFlushing_)
8787
bool isFlushing_ = false;
8888

89-
AVRational timeBase_ = {0, 0};
89+
AVRational timeBase_ = {0, 1};
90+
AVRational frameRateAvgFromFFmpeg_ = {0, 1};
9091

9192
UniqueAVBSFContext bitstreamFilter_;
9293

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,4 +501,26 @@ AVIOContext* avioAllocContext(
501501
seek);
502502
}
503503

504+
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
505+
// To perform the multiplication before the division, av_q2d is not used
506+
return static_cast<double>(pts) * timeBase.num / timeBase.den;
507+
}
508+
509+
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
510+
return static_cast<int64_t>(
511+
std::round(seconds * timeBase.den / timeBase.num));
512+
}
513+
514+
int64_t computeSafeDuration(
515+
const AVRational& frameRate,
516+
const AVRational& timeBase) {
517+
if (frameRate.num <= 0 || frameRate.den <= 0 || timeBase.num <= 0 ||
518+
timeBase.den <= 0) {
519+
return 0;
520+
} else {
521+
return (static_cast<int64_t>(frameRate.den) * timeBase.den) /
522+
(static_cast<int64_t>(timeBase.num) * frameRate.num);
523+
}
524+
}
525+
504526
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,10 @@ AVIOContext* avioAllocContext(
232232
AVIOWriteFunction write_packet,
233233
AVIOSeekFunction seek);
234234

235+
double ptsToSeconds(int64_t pts, const AVRational& timeBase);
236+
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase);
237+
int64_t computeSafeDuration(
238+
const AVRational& frameRate,
239+
const AVRational& timeBase);
240+
235241
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,6 @@
1717
namespace facebook::torchcodec {
1818
namespace {
1919

20-
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
21-
// To perform the multiplication before the division, av_q2d is not used
22-
return static_cast<double>(pts) * timeBase.num / timeBase.den;
23-
}
24-
25-
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
26-
return static_cast<int64_t>(
27-
std::round(seconds * timeBase.den / timeBase.num));
28-
}
29-
3020
// Some videos aren't properly encoded and do not specify pts values for
3121
// packets, and thus for frames. Unset values correspond to INT64_MIN. When that
3222
// happens, we fallback to the dts value which hopefully exists and is correct.

0 commit comments

Comments
 (0)