@@ -425,6 +425,9 @@ void VideoDecoder::addStream(
425425 TORCH_CHECK (
426426 activeStreamIndex_ == NO_ACTIVE_STREAM,
427427 " Can only add one single stream." );
428+ TORCH_CHECK (
429+ mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO,
430+ " Can only add video or audio streams." );
428431 TORCH_CHECK (formatContext_.get () != nullptr );
429432
430433 AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
@@ -448,9 +451,10 @@ void VideoDecoder::addStream(
448451
449452 // This should never happen, checking just to be safe.
450453 TORCH_CHECK (
451- streamInfo.stream ->codecpar ->codec_type == mediaType,
452- " FFmpeg found stream with index " , activeStreamIndex_, " which is of the wrong media type." );
453-
454+ streamInfo.stream ->codecpar ->codec_type == mediaType,
455+ " FFmpeg found stream with index " ,
456+ activeStreamIndex_,
457+ " which is of the wrong media type." );
454458
455459 if (mediaType == AVMEDIA_TYPE_VIDEO &&
456460 videoStreamOptions.device .type () == torch::kCUDA ) {
@@ -1076,8 +1080,10 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10761080 avFrame->pts , formatContext_->streams [streamIndex]->time_base );
10771081 frameOutput.durationSeconds = ptsToSeconds (
10781082 getDuration (avFrame), formatContext_->streams [streamIndex]->time_base );
1079- // TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080- if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1083+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1084+ // TODO: handle preAllocatedTensor for audio
1085+ convertAudioAVFrameToFrameOutputOnCPU (avFrameStream, frameOutput);
1086+ } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
10811087 convertAVFrameToFrameOutputOnCPU (
10821088 avFrameStream, frameOutput, preAllocatedOutputTensor);
10831089 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCUDA ) {
@@ -1253,6 +1259,39 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12531259 filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
12541260}
12551261
1262+ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1263+ VideoDecoder::AVFrameStream& avFrameStream,
1264+ FrameOutput& frameOutput) {
1265+ AVFrame* avFrame = avFrameStream.avFrame .get ();
1266+
1267+ auto numSamples = avFrame->nb_samples ; // per channel
1268+ auto numChannels =
1269+ avFrame->ch_layout .nb_channels ; // TODO handle other ffmpeg versions
1270+
1271+ // TODO: dtype should be format-dependent
1272+ torch::Tensor data = torch::empty ({numChannels, numSamples}, torch::kFloat32 );
1273+
1274+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1275+ // TODO Implement all formats
1276+ switch (format) {
1277+ case AV_SAMPLE_FMT_FLTP: {
1278+ uint8_t * pData = static_cast <uint8_t *>(data.data_ptr ());
1279+ for (auto channel = 0 ; channel < numChannels; ++channel) {
1280+ auto numBytesToCopy = numSamples * av_get_bytes_per_sample (format);
1281+ memcpy (pData, avFrame->extended_data [channel], numBytesToCopy);
1282+ pData += numBytesToCopy;
1283+ }
1284+ break ;
1285+ }
1286+ default :
1287+ TORCH_CHECK (
1288+ false ,
1289+ " Unsupported audio format (yet!): " ,
1290+ av_get_sample_fmt_name (format));
1291+ }
1292+ frameOutput.data = data;
1293+ }
1294+
12561295// --------------------------------------------------------------------------
12571296// OUTPUT ALLOCATION AND SHAPE CONVERSION
12581297// --------------------------------------------------------------------------
@@ -1298,6 +1337,10 @@ torch::Tensor allocateEmptyHWCTensor(
12981337// Calling permute() is guaranteed to return a view as per the docs:
12991338// https://pytorch.org/docs/stable/generated/torch.permute.html
13001339torch::Tensor VideoDecoder::maybePermuteHWC2CHW (torch::Tensor& hwcTensor) {
1340+ if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_AUDIO) {
1341+ // TODO: Is this really how we want to handle audio?
1342+ return hwcTensor;
1343+ }
13011344 if (streamInfos_[activeStreamIndex_].videoStreamOptions .dimensionOrder ==
13021345 " NHWC" ) {
13031346 return hwcTensor;
0 commit comments