@@ -133,6 +133,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
133
133
return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
134
134
}
135
135
136
+ cudaVideoCodec validateCodecSupport (AVCodecID codecId) {
137
+ switch (codecId) {
138
+ case AV_CODEC_ID_H264:
139
+ return cudaVideoCodec_H264;
140
+ case AV_CODEC_ID_HEVC:
141
+ return cudaVideoCodec_HEVC;
142
+ // TODONVDEC P0: support more codecs
143
+ // case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
144
+ // case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
145
+ // case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
146
+ // case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
147
+ // case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
148
+ default : {
149
+ TORCH_CHECK (false , " Unsupported codec type: " , avcodec_get_name (codecId));
150
+ }
151
+ }
152
+ }
153
+
136
154
} // namespace
137
155
138
156
BetaCudaDeviceInterface::BetaCudaDeviceInterface (const torch::Device& device)
@@ -158,29 +176,62 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
158
176
}
159
177
}
160
178
161
- void BetaCudaDeviceInterface::initializeInterface (AVStream* avStream) {
162
- torch::Tensor dummyTensorForCudaInitialization = torch::empty (
163
- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
179
+ void BetaCudaDeviceInterface::initializeBSF (
180
+ const AVCodecParameters* codecPar,
181
+ const UniqueDecodingAVFormatContext& avFormatCtx) {
182
+ // Setup bit stream filters (BSF):
183
+ // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
184
+ // This is only needed for some formats, like H264 or HEVC.
164
185
165
- TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
166
- timeBase_ = avStream->time_base ;
186
+ TORCH_CHECK (codecPar != nullptr , " codecPar cannot be null" );
187
+ TORCH_CHECK (avFormatCtx != nullptr , " AVFormatContext cannot be null" );
188
+ TORCH_CHECK (
189
+ avFormatCtx->iformat != nullptr ,
190
+ " AVFormatContext->iformat cannot be null" );
191
+ std::string filterName;
192
+
193
+ // Matching logic is taken from DALI
194
+ switch (codecPar->codec_id ) {
195
+ case AV_CODEC_ID_H264: {
196
+ const std::string formatName = avFormatCtx->iformat ->long_name
197
+ ? avFormatCtx->iformat ->long_name
198
+ : " " ;
199
+
200
+ if (formatName == " QuickTime / MOV" ||
201
+ formatName == " FLV (Flash Video)" ||
202
+ formatName == " Matroska / WebM" || formatName == " raw H.264 video" ) {
203
+ filterName = " h264_mp4toannexb" ;
204
+ }
205
+ break ;
206
+ }
167
207
168
- const AVCodecParameters* codecpar = avStream->codecpar ;
169
- TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
208
+ case AV_CODEC_ID_HEVC: {
209
+ const std::string formatName = avFormatCtx->iformat ->long_name
210
+ ? avFormatCtx->iformat ->long_name
211
+ : " " ;
170
212
171
- TORCH_CHECK (
172
- // TODONVDEC P0 support more
173
- avStream->codecpar ->codec_id == AV_CODEC_ID_H264,
174
- " Can only do H264 for now" );
213
+ if (formatName == " QuickTime / MOV" ||
214
+ formatName == " FLV (Flash Video)" ||
215
+ formatName == " Matroska / WebM" || formatName == " raw HEVC video" ) {
216
+ filterName = " hevc_mp4toannexb" ;
217
+ }
218
+ break ;
219
+ }
175
220
176
- // Setup bit stream filters (BSF):
177
- // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
178
- // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
179
- // now we apply BSF unconditionally, but it should be optional and dependent
180
- // on codec and container.
181
- const AVBitStreamFilter* avBSF = av_bsf_get_by_name (" h264_mp4toannexb" );
221
+ default :
222
+ // No bitstream filter needed for other codecs
223
+ // TODONVDEC P1 MPEG4 will need one!
224
+ break ;
225
+ }
226
+
227
+ if (filterName.empty ()) {
228
+ // Only initialize BSF if we actually need one
229
+ return ;
230
+ }
231
+
232
+ const AVBitStreamFilter* avBSF = av_bsf_get_by_name (filterName.c_str ());
182
233
TORCH_CHECK (
183
- avBSF != nullptr , " Failed to find h264_mp4toannexb bitstream filter" );
234
+ avBSF != nullptr , " Failed to find bitstream filter: " , filterName );
184
235
185
236
AVBSFContext* avBSFContext = nullptr ;
186
237
int retVal = av_bsf_alloc (avBSF, &avBSFContext);
@@ -191,7 +242,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
191
242
192
243
bitstreamFilter_.reset (avBSFContext);
193
244
194
- retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecpar );
245
+ retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecPar );
195
246
TORCH_CHECK (
196
247
retVal >= AVSUCCESS,
197
248
" Failed to copy codec parameters: " ,
@@ -202,10 +253,25 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
202
253
retVal == AVSUCCESS,
203
254
" Failed to initialize bitstream filter: " ,
204
255
getFFMPEGErrorStringFromErrorCode (retVal));
256
+ }
257
+
258
+ void BetaCudaDeviceInterface::initializeInterface (
259
+ const AVStream* avStream,
260
+ const UniqueDecodingAVFormatContext& avFormatCtx) {
261
+ torch::Tensor dummyTensorForCudaInitialization = torch::empty (
262
+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
263
+
264
+ TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
265
+ timeBase_ = avStream->time_base ;
266
+
267
+ const AVCodecParameters* codecPar = avStream->codecpar ;
268
+ TORCH_CHECK (codecPar != nullptr , " CodecParameters cannot be null" );
269
+
270
+ initializeBSF (codecPar, avFormatCtx);
205
271
206
272
// Create parser. Default values that aren't obvious are taken from DALI.
207
273
CUVIDPARSERPARAMS parserParams = {};
208
- parserParams.CodecType = cudaVideoCodec_H264 ;
274
+ parserParams.CodecType = validateCodecSupport (codecPar-> codec_id ) ;
209
275
parserParams.ulMaxNumDecodeSurfaces = 8 ;
210
276
parserParams.ulMaxDisplayDelay = 0 ;
211
277
// Callback setup, all are triggered by the parser within a call
0 commit comments