Skip to content

Commit 7176788

Browse files
committed
Add h265 support
1 parent f55dcc0 commit 7176788

File tree

7 files changed

+137
-32
lines changed

7 files changed

+137
-32
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
133133
return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{});
134134
}
135135

136+
// Translates an FFmpeg codec ID into the matching cudaVideoCodec value (the
// result is used as the CodecType of the CUVID parser parameters).
//
// Only H264 and HEVC are mapped for now. Any other codec ID fails a
// TORCH_CHECK whose message contains the codec name reported by
// avcodec_get_name().
cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
  if (codecId == AV_CODEC_ID_H264) {
    return cudaVideoCodec_H264;
  }
  if (codecId == AV_CODEC_ID_HEVC) {
    return cudaVideoCodec_HEVC;
  }

  // TODONVDEC P0: support more codecs. Candidate mappings:
  //   AV_CODEC_ID_AV1   -> cudaVideoCodec_AV1
  //   AV_CODEC_ID_MPEG4 -> cudaVideoCodec_MPEG4
  //   AV_CODEC_ID_VP8   -> cudaVideoCodec_VP8
  //   AV_CODEC_ID_VP9   -> cudaVideoCodec_VP9
  //   AV_CODEC_ID_MJPEG -> cudaVideoCodec_JPEG
  TORCH_CHECK(false, "Unsupported codec type: ", avcodec_get_name(codecId));
}
153+
136154
} // namespace
137155

138156
BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -158,29 +176,62 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
158176
}
159177
}
160178

161-
void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
162-
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
163-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
179+
void BetaCudaDeviceInterface::initializeBSF(
180+
const AVCodecParameters* codecPar,
181+
const UniqueDecodingAVFormatContext& avFormatCtx) {
182+
// Setup bit stream filters (BSF):
183+
// https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
184+
// This is only needed for some formats, like H264 or HEVC.
164185

165-
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
166-
timeBase_ = avStream->time_base;
186+
TORCH_CHECK(codecPar != nullptr, "codecPar cannot be null");
187+
TORCH_CHECK(avFormatCtx != nullptr, "AVFormatContext cannot be null");
188+
TORCH_CHECK(
189+
avFormatCtx->iformat != nullptr,
190+
"AVFormatContext->iformat cannot be null");
191+
std::string filterName;
192+
193+
// Matching logic is taken from DALI
194+
switch (codecPar->codec_id) {
195+
case AV_CODEC_ID_H264: {
196+
const std::string formatName = avFormatCtx->iformat->long_name
197+
? avFormatCtx->iformat->long_name
198+
: "";
199+
200+
if (formatName == "QuickTime / MOV" ||
201+
formatName == "FLV (Flash Video)" ||
202+
formatName == "Matroska / WebM" || formatName == "raw H.264 video") {
203+
filterName = "h264_mp4toannexb";
204+
}
205+
break;
206+
}
167207

168-
const AVCodecParameters* codecpar = avStream->codecpar;
169-
TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null");
208+
case AV_CODEC_ID_HEVC: {
209+
const std::string formatName = avFormatCtx->iformat->long_name
210+
? avFormatCtx->iformat->long_name
211+
: "";
170212

171-
TORCH_CHECK(
172-
// TODONVDEC P0 support more
173-
avStream->codecpar->codec_id == AV_CODEC_ID_H264,
174-
"Can only do H264 for now");
213+
if (formatName == "QuickTime / MOV" ||
214+
formatName == "FLV (Flash Video)" ||
215+
formatName == "Matroska / WebM" || formatName == "raw HEVC video") {
216+
filterName = "hevc_mp4toannexb";
217+
}
218+
break;
219+
}
175220

176-
// Setup bit stream filters (BSF):
177-
// https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
178-
// This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
179-
// now we apply BSF unconditionally, but it should be optional and dependent
180-
// on codec and container.
181-
const AVBitStreamFilter* avBSF = av_bsf_get_by_name("h264_mp4toannexb");
221+
default:
222+
// No bitstream filter needed for other codecs
223+
// TODONVDEC P1 MPEG4 will need one!
224+
break;
225+
}
226+
227+
if (filterName.empty()) {
228+
// Only initialize BSF if we actually need one
229+
return;
230+
}
231+
232+
const AVBitStreamFilter* avBSF = av_bsf_get_by_name(filterName.c_str());
182233
TORCH_CHECK(
183-
avBSF != nullptr, "Failed to find h264_mp4toannexb bitstream filter");
234+
avBSF != nullptr, "Failed to find bitstream filter: ", filterName);
184235

185236
AVBSFContext* avBSFContext = nullptr;
186237
int retVal = av_bsf_alloc(avBSF, &avBSFContext);
@@ -191,7 +242,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
191242

192243
bitstreamFilter_.reset(avBSFContext);
193244

194-
retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecpar);
245+
retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecPar);
195246
TORCH_CHECK(
196247
retVal >= AVSUCCESS,
197248
"Failed to copy codec parameters: ",
@@ -202,10 +253,25 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
202253
retVal == AVSUCCESS,
203254
"Failed to initialize bitstream filter: ",
204255
getFFMPEGErrorStringFromErrorCode(retVal));
256+
}
257+
258+
void BetaCudaDeviceInterface::initializeInterface(
259+
const AVStream* avStream,
260+
const UniqueDecodingAVFormatContext& avFormatCtx) {
261+
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
262+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
263+
264+
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
265+
timeBase_ = avStream->time_base;
266+
267+
const AVCodecParameters* codecPar = avStream->codecpar;
268+
TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null");
269+
270+
initializeBSF(codecPar, avFormatCtx);
205271

206272
// Create parser. Default values that aren't obvious are taken from DALI.
207273
CUVIDPARSERPARAMS parserParams = {};
208-
parserParams.CodecType = cudaVideoCodec_H264;
274+
parserParams.CodecType = validateCodecSupport(codecPar->codec_id);
209275
parserParams.ulMaxNumDecodeSurfaces = 8;
210276
parserParams.ulMaxDisplayDelay = 0;
211277
// Callback setup, all are triggered by the parser within a call

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
3737
explicit BetaCudaDeviceInterface(const torch::Device& device);
3838
virtual ~BetaCudaDeviceInterface();
3939

40-
void initializeInterface(AVStream* stream) override;
40+
void initializeInterface(
41+
const AVStream* stream,
42+
const UniqueDecodingAVFormatContext& avFormatCtx) override;
4143

4244
void convertAVFrameToFrameOutput(
4345
const VideoStreamOptions& videoStreamOptions,
@@ -63,6 +65,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
6365
private:
6466
// Apply bitstream filter, modifies packet in-place
6567
void applyBSF(ReferenceAVPacket& packet);
68+
// Set up bitstreamFilter_ for the stream described by codecPar, if one is
// needed. A filter (h264_mp4toannexb / hevc_mp4toannexb) is only created for
// H264 / HEVC streams whose container format name is in a fixed list of
// known formats; for every other codec/container combination this returns
// without creating a filter. Both arguments, and avFormatCtx->iformat, must
// be non-null.
void initializeBSF(
69+
const AVCodecParameters* codecPar,
70+
const UniqueDecodingAVFormatContext& avFormatCtx);
6671

6772
UniqueAVFrame convertCudaFrameToAVFrame(
6873
CUdeviceptr framePtr,
@@ -156,7 +161,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
156161
// frameReadyForDecoding()
157162
// cuvidDecodePicture() Send frame to NVDEC for async decoding
158163
//
159-
// receiveFrame() -> EAGAIN Frame is potentially already decoded
164+
// receiveFrame() -> EAGAIN Frame is potentially already decoded
160165
// and could be returned, but we don't
161166
// know because frameReadyInDisplayOrder
162167
// hasn't been triggered yet. We'll only

src/torchcodec/_core/DeviceInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class DeviceInterface {
5555
virtual void initializeContext(
5656
[[maybe_unused]] AVCodecContext* codecContext) {}
5757

58-
virtual void initializeInterface([[maybe_unused]] AVStream* stream) {}
58+
// Optional hook for device-specific setup that needs the stream and the
// format context it belongs to. The default implementation does nothing
// (both parameters are unused); device interfaces that need per-stream
// initialization override it. SingleStreamDecoder::addStream calls this for
// video streams right after initializeContext().
virtual void initializeInterface(
59+
[[maybe_unused]] const AVStream* stream,
60+
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {}
5961

6062
virtual void convertAVFrameToFrameOutput(
6163
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ void SingleStreamDecoder::addStream(
462462
if (mediaType == AVMEDIA_TYPE_VIDEO) {
463463
if (deviceInterface_) {
464464
deviceInterface_->initializeContext(codecContext);
465-
deviceInterface_->initializeInterface(streamInfo.stream);
465+
deviceInterface_->initializeInterface(streamInfo.stream, formatContext_);
466466
}
467467
}
468468

test/resources/testsrc2_h265.mp4

890 KB
Binary file not shown.

test/test_decoders.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
SINE_MONO_S32_44100,
4545
SINE_MONO_S32_8000,
4646
TEST_SRC_2_720P,
47+
TEST_SRC_2_720P_H265,
4748
)
4849

4950

@@ -1413,7 +1414,9 @@ def test_get_frames_at_tensor_indices(self):
14131414
# assert_tensor_close_on_at_least or something like that.
14141415

14151416
@needs_cuda
1416-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1417+
@pytest.mark.parametrize(
1418+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1419+
)
14171420
@pytest.mark.parametrize("contiguous_indices", (True, False))
14181421
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14191422
def test_beta_cuda_interface_get_frame_at(
@@ -1443,7 +1446,9 @@ def test_beta_cuda_interface_get_frame_at(
14431446
assert beta_frame.duration_seconds == ref_frame.duration_seconds
14441447

14451448
@needs_cuda
1446-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1449+
@pytest.mark.parametrize(
1450+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1451+
)
14471452
@pytest.mark.parametrize("contiguous_indices", (True, False))
14481453
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14491454
def test_beta_cuda_interface_get_frames_at(
@@ -1474,7 +1479,9 @@ def test_beta_cuda_interface_get_frames_at(
14741479
)
14751480

14761481
@needs_cuda
1477-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1482+
@pytest.mark.parametrize(
1483+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1484+
)
14781485
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14791486
def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
14801487
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
@@ -1496,7 +1503,9 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
14961503
assert beta_frame.duration_seconds == ref_frame.duration_seconds
14971504

14981505
@needs_cuda
1499-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1506+
@pytest.mark.parametrize(
1507+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1508+
)
15001509
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15011510
def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15021511
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
@@ -1519,7 +1528,9 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15191528
)
15201529

15211530
@needs_cuda
1522-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1531+
@pytest.mark.parametrize(
1532+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1533+
)
15231534
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15241535
def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15251536

@@ -1539,12 +1550,24 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15391550
assert beta_frame.pts_seconds == ref_frame.pts_seconds
15401551
assert beta_frame.duration_seconds == ref_frame.duration_seconds
15411552

1553+
@needs_cuda
1554+
def test_beta_cuda_interface_small_h265(self):
    """Check how each CUDA interface handles the small H265 test video.

    The default CUDA interface decodes frame 0 without raising, while the
    beta interface rejects the video with a "too small" RuntimeError.
    """
    # TODONVDEC P2 investigate why/how the default interface can decode this
    # video.

    # The default interface accepts this video - it is not yet understood why.
    default_decoder = VideoDecoder(H265_VIDEO.path, device="cuda")
    default_decoder.get_frame_at(0)

    # The beta interface rejects it because of the input validation checks
    # that were taken from DALI.
    expected_error = "Video is too small in at least one dimension. Provided: 128x128 vs supported:144x144"
    with pytest.raises(RuntimeError, match=expected_error):
        beta_decoder = VideoDecoder(H265_VIDEO.path, device="cuda:0:beta")
        beta_decoder.get_frame_at(0)
1566+
15421567
@needs_cuda
15431568
def test_beta_cuda_interface_error(self):
1544-
with pytest.raises(RuntimeError, match="Can only do H264 for now"):
1569+
with pytest.raises(RuntimeError, match="Unsupported codec type: av1"):
15451570
VideoDecoder(AV1_VIDEO.path, device="cuda:0:beta")
1546-
with pytest.raises(RuntimeError, match="Can only do H264 for now"):
1547-
VideoDecoder(H265_VIDEO.path, device="cuda:0:beta")
15481571
with pytest.raises(RuntimeError, match="Unsupported device"):
15491572
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
15501573

test/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,3 +688,12 @@ def sample_format(self) -> str:
688688
},
689689
frames={0: {}}, # Not needed for now
690690
)
691+
# 1280x720 testsrc2 pattern encoded with libx265 (H.265 / HEVC). Generated with:
# ffmpeg -f lavfi -i testsrc2=duration=10:size=1280x720:rate=30 -c:v libx265 -crf 23 -preset medium output.mp4
TEST_SRC_2_720P_H265 = TestVideo(
    filename="testsrc2_h265.mp4",
    default_stream_index=0,
    stream_infos={0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3)},
    frames={0: {}},  # Not needed for now
)

0 commit comments

Comments
 (0)