Skip to content

Commit 8ad66ce

Browse files
committed
Add AV1 support
1 parent 7ea3ca9 commit 8ad66ce

File tree

6 files changed

+107
-28
lines changed

6 files changed

+107
-28
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
144144
return cudaVideoCodec_H264;
145145
case AV_CODEC_ID_HEVC:
146146
return cudaVideoCodec_HEVC;
147+
case AV_CODEC_ID_AV1:
148+
return cudaVideoCodec_AV1;
147149
// TODONVDEC P0: support more codecs
148-
// case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
149150
// case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
150151
// case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
151152
// case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
@@ -268,6 +269,7 @@ void BetaCudaDeviceInterface::initializeInterface(
268269

269270
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
270271
timeBase_ = avStream->time_base;
272+
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
271273

272274
const AVCodecParameters* codecPar = avStream->codecpar;
273275
TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null");
@@ -484,14 +486,19 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
484486
avFrame->format = AV_PIX_FMT_CUDA;
485487
avFrame->pts = dispInfo.timestamp;
486488

487-
// TODONVDEC P0: Zero division error!!!
488-
// TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the
489-
// similar SingleStreamDecoder stuff there too.
490-
unsigned int frameRateNum = videoFormat_.frame_rate.numerator;
491-
unsigned int frameRateDen = videoFormat_.frame_rate.denominator;
492-
int64_t duration = static_cast<int64_t>((frameRateDen * timeBase_.den)) /
493-
(frameRateNum * timeBase_.num);
494-
setDuration(avFrame, duration);
489+
// TODONVDEC P2: We compute the duration based on average frame rate info:
490+
// either from NVCUVID if it's valid, otherwise from FFmpeg as fallback. But
491+
// both of these are based on average frame rate, so if the video has
492+
// variable frame rate, the durations may be off. We should try to see if we
493+
// can set the duration more accurately. Unfortunately it's not given by
494+
// dispInfo. One option would be to set it based on the pts difference between
495+
// consecutive frames, if the next frame is already available.
496+
int frameRateNum = static_cast<int>(videoFormat_.frame_rate.numerator);
497+
int frameRateDen = static_cast<int>(videoFormat_.frame_rate.denominator);
498+
AVRational frameRate = (frameRateNum > 0 && frameRateDen > 0)
499+
? AVRational{frameRateNum, frameRateDen}
500+
: frameRateAvgFromFFmpeg_;
501+
setDuration(avFrame, computeSafeDuration(frameRate, timeBase_));
495502

496503
// We need to assign the frame colorspace. This is crucial for proper color
497504
// conversion. NVCUVID stores that in the matrix_coefficients field, but

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
8686
// isFlushing_)
8787
bool isFlushing_ = false;
8888

89-
AVRational timeBase_ = {0, 0};
89+
AVRational timeBase_ = {0, 1};
90+
AVRational frameRateAvgFromFFmpeg_ = {0, 1};
9091

9192
UniqueAVBSFContext bitstreamFilter_;
9293

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,4 +501,26 @@ AVIOContext* avioAllocContext(
501501
seek);
502502
}
503503

504+
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
505+
// To perform the multiplication before the division, av_q2d is not used
506+
return static_cast<double>(pts) * timeBase.num / timeBase.den;
507+
}
508+
509+
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
510+
return static_cast<int64_t>(
511+
std::round(seconds * timeBase.den / timeBase.num));
512+
}
513+
514+
int64_t computeSafeDuration(
515+
const AVRational& frameRate,
516+
const AVRational& timeBase) {
517+
if (frameRate.num <= 0 || frameRate.den <= 0 || timeBase.num <= 0 ||
518+
timeBase.den <= 0) {
519+
return 0;
520+
} else {
521+
return (static_cast<int64_t>(frameRate.den) * timeBase.den) /
522+
(static_cast<int64_t>(timeBase.num) * frameRate.num);
523+
}
524+
}
525+
504526
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,10 @@ AVIOContext* avioAllocContext(
232232
AVIOWriteFunction write_packet,
233233
AVIOSeekFunction seek);
234234

235+
// Converts a pts value expressed in timeBase units to seconds, computed as
// pts * timeBase.num / timeBase.den in double precision.
double ptsToSeconds(int64_t pts, const AVRational& timeBase);
236+
// Converts a time in seconds to a pts value in timeBase units, computed as
// seconds * timeBase.den / timeBase.num and rounded with std::round.
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase);
237+
// Computes a duration in timeBase units from a frame rate, as
// (frameRate.den * timeBase.den) / (timeBase.num * frameRate.num) using
// 64-bit integer division. Returns 0 if any num/den field of either argument
// is <= 0, so the computation never divides by zero.
int64_t computeSafeDuration(
238+
const AVRational& frameRate,
239+
const AVRational& timeBase);
240+
235241
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,6 @@
1717
namespace facebook::torchcodec {
1818
namespace {
1919

20-
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
21-
// To perform the multiplication before the division, av_q2d is not used
22-
return static_cast<double>(pts) * timeBase.num / timeBase.den;
23-
}
24-
25-
int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
26-
return static_cast<int64_t>(
27-
std::round(seconds * timeBase.den / timeBase.num));
28-
}
29-
3020
// Some videos aren't properly encoded and do not specify pts values for
3121
// packets, and thus for frames. Unset values correspond to INT64_MIN. When that
3222
// happens, we fallback to the dts value which hopefully exists and is correct.

test/test_decoders.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,13 +1417,23 @@ def test_get_frames_at_tensor_indices(self):
14171417

14181418
@needs_cuda
14191419
@pytest.mark.parametrize(
1420-
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1420+
"asset",
1421+
(
1422+
NASA_VIDEO,
1423+
TEST_SRC_2_720P,
1424+
BT709_FULL_RANGE,
1425+
TEST_SRC_2_720P_H265,
1426+
AV1_VIDEO,
1427+
),
14211428
)
14221429
@pytest.mark.parametrize("contiguous_indices", (True, False))
14231430
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14241431
def test_beta_cuda_interface_get_frame_at(
14251432
self, asset, contiguous_indices, seek_mode
14261433
):
1434+
if asset == AV1_VIDEO and seek_mode == "approximate":
1435+
pytest.skip("AV1 asset doesn't work with approximate mode")
1436+
14271437
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
14281438
beta_decoder = VideoDecoder(
14291439
asset.path, device="cuda:0:beta", seek_mode=seek_mode
@@ -1449,13 +1459,23 @@ def test_beta_cuda_interface_get_frame_at(
14491459

14501460
@needs_cuda
14511461
@pytest.mark.parametrize(
1452-
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1462+
"asset",
1463+
(
1464+
NASA_VIDEO,
1465+
TEST_SRC_2_720P,
1466+
BT709_FULL_RANGE,
1467+
TEST_SRC_2_720P_H265,
1468+
AV1_VIDEO,
1469+
),
14531470
)
14541471
@pytest.mark.parametrize("contiguous_indices", (True, False))
14551472
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14561473
def test_beta_cuda_interface_get_frames_at(
14571474
self, asset, contiguous_indices, seek_mode
14581475
):
1476+
if asset == AV1_VIDEO and seek_mode == "approximate":
1477+
pytest.skip("AV1 asset doesn't work with approximate mode")
1478+
14591479
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
14601480
beta_decoder = VideoDecoder(
14611481
asset.path, device="cuda:0:beta", seek_mode=seek_mode
@@ -1482,10 +1502,20 @@ def test_beta_cuda_interface_get_frames_at(
14821502

14831503
@needs_cuda
14841504
@pytest.mark.parametrize(
1485-
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1505+
"asset",
1506+
(
1507+
NASA_VIDEO,
1508+
TEST_SRC_2_720P,
1509+
BT709_FULL_RANGE,
1510+
TEST_SRC_2_720P_H265,
1511+
AV1_VIDEO,
1512+
),
14861513
)
14871514
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14881515
def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
1516+
if asset == AV1_VIDEO and seek_mode == "approximate":
1517+
pytest.skip("AV1 asset doesn't work with approximate mode")
1518+
14891519
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
14901520
beta_decoder = VideoDecoder(
14911521
asset.path, device="cuda:0:beta", seek_mode=seek_mode
@@ -1506,10 +1536,20 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
15061536

15071537
@needs_cuda
15081538
@pytest.mark.parametrize(
1509-
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1539+
"asset",
1540+
(
1541+
NASA_VIDEO,
1542+
TEST_SRC_2_720P,
1543+
BT709_FULL_RANGE,
1544+
TEST_SRC_2_720P_H265,
1545+
AV1_VIDEO,
1546+
),
15101547
)
15111548
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15121549
def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
1550+
if asset == AV1_VIDEO and seek_mode == "approximate":
1551+
pytest.skip("AV1 asset doesn't work with approximate mode")
1552+
15131553
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
15141554
beta_decoder = VideoDecoder(
15151555
asset.path, device="cuda:0:beta", seek_mode=seek_mode
@@ -1531,10 +1571,19 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15311571

15321572
@needs_cuda
15331573
@pytest.mark.parametrize(
1534-
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1574+
"asset",
1575+
(
1576+
NASA_VIDEO,
1577+
TEST_SRC_2_720P,
1578+
BT709_FULL_RANGE,
1579+
TEST_SRC_2_720P_H265,
1580+
AV1_VIDEO,
1581+
),
15351582
)
15361583
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15371584
def test_beta_cuda_interface_backwards(self, asset, seek_mode):
1585+
if asset == AV1_VIDEO and seek_mode == "approximate":
1586+
pytest.skip("AV1 asset doesn't work with approximate mode")
15381587

15391588
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
15401589
beta_decoder = VideoDecoder(
@@ -1543,8 +1592,14 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15431592

15441593
assert ref_decoder.metadata == beta_decoder.metadata
15451594

1546-
for frame_index in [0, 100, 10, 50, 20, 200, 150, 389]:
1595+
for frame_index in [0, 1, 2, 1, 0, 100, 10, 50, 20, 200, 150, 150, 150, 389, 2]:
1596+
# This is ugly, but OK: the indices values above are relevant for
1597+
# the NASA_VIDEO. We need to avoid going out of bounds for other
1598+
# videos so we cap the frame_index. This test still serves its
1599+
# purpose: no matter what the range of the video, we're still doing
1600+
# backwards seeks.
15471601
frame_index = min(frame_index, len(ref_decoder) - 1)
1602+
15481603
ref_frame = ref_decoder.get_frame_at(frame_index)
15491604
beta_frame = beta_decoder.get_frame_at(frame_index)
15501605
torch.testing.assert_close(beta_frame.data, ref_frame.data, rtol=0, atol=0)
@@ -1568,8 +1623,6 @@ def test_beta_cuda_interface_small_h265(self):
15681623

15691624
@needs_cuda
15701625
def test_beta_cuda_interface_error(self):
1571-
with pytest.raises(RuntimeError, match="Unsupported codec type: av1"):
1572-
VideoDecoder(AV1_VIDEO.path, device="cuda:0:beta")
15731626
with pytest.raises(RuntimeError, match="Unsupported device"):
15741627
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
15751628

0 commit comments

Comments
 (0)