Skip to content

Commit 79954f9

Browse files
authored
Audio metadata support (#535)
1 parent da9164e commit 79954f9

16 files changed

+584
-245
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from pathlib import Path
8+
from typing import Optional, Union
9+
10+
from torch import Tensor
11+
12+
from torchcodec.decoders import _core as core
13+
from torchcodec.decoders._decoder_utils import (
14+
create_decoder,
15+
get_and_validate_stream_metadata,
16+
)
17+
18+
19+
class AudioDecoder:
20+
"""TODO-AUDIO docs"""
21+
22+
def __init__(
23+
self,
24+
source: Union[str, Path, bytes, Tensor],
25+
*,
26+
stream_index: Optional[int] = None,
27+
):
28+
self._decoder = create_decoder(source=source, seek_mode="approximate")
29+
30+
core.add_audio_stream(self._decoder, stream_index=stream_index)
31+
32+
(
33+
self.metadata,
34+
self.stream_index,
35+
self._begin_stream_seconds,
36+
self._end_stream_seconds,
37+
) = get_and_validate_stream_metadata(
38+
decoder=self._decoder, stream_index=stream_index, media_type="audio"
39+
)

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63+
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
64+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
65+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66+
int numChannels = avCodecContext->ch_layout.nb_channels;
67+
#else
68+
int numChannels = avCodecContext->channels;
69+
#endif
70+
71+
return static_cast<int64_t>(numChannels);
72+
}
73+
6374
AVIOBytesContext::AVIOBytesContext(
6475
const void* data,
6576
size_t dataSize,

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142+
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
143+
142144
// Returns true if sws_scale can handle unaligned data.
143145
bool canSwsScaleHandleUnalignedData();
144146

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 102 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,11 @@ void VideoDecoder::initializeDecoder() {
162162
av_q2d(avStream->time_base) * avStream->duration;
163163
}
164164

165-
double fps = av_q2d(avStream->r_frame_rate);
166-
if (fps > 0) {
167-
streamMetadata.averageFps = fps;
168-
}
169-
170165
if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
166+
double fps = av_q2d(avStream->r_frame_rate);
167+
if (fps > 0) {
168+
streamMetadata.averageFps = fps;
169+
}
171170
containerMetadata_.numVideoStreams++;
172171
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
173172
containerMetadata_.numAudioStreams++;
@@ -340,7 +339,7 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
340339
}
341340

342341
torch::Tensor VideoDecoder::getKeyFrameIndices() {
343-
validateActiveStream();
342+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
344343
validateScannedAllStreams("getKeyFrameIndices");
345344

346345
const std::vector<FrameInfo>& keyFrames =
@@ -409,84 +408,76 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
409408
}
410409
}
411410

412-
void VideoDecoder::addVideoStream(
411+
void VideoDecoder::addStream(
413412
int streamIndex,
414-
const VideoStreamOptions& videoStreamOptions) {
413+
AVMediaType mediaType,
414+
const torch::Device& device,
415+
std::optional<int> ffmpegThreadCount) {
415416
TORCH_CHECK(
416417
activeStreamIndex_ == NO_ACTIVE_STREAM,
417418
"Can only add one single stream.");
419+
TORCH_CHECK(
420+
mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO,
421+
"Can only add video or audio streams.");
418422
TORCH_CHECK(formatContext_.get() != nullptr);
419423

420424
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
421425

422426
activeStreamIndex_ = av_find_best_stream(
423-
formatContext_.get(), AVMEDIA_TYPE_VIDEO, streamIndex, -1, &avCodec, 0);
427+
formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0);
428+
424429
if (activeStreamIndex_ < 0) {
425-
throw std::invalid_argument("No valid stream found in input file.");
430+
throw std::invalid_argument(
431+
"No valid stream found in input file. Is " +
432+
std::to_string(streamIndex) + " of the desired media type?");
426433
}
434+
427435
TORCH_CHECK(avCodec != nullptr);
428436

429437
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
430438
streamInfo.streamIndex = activeStreamIndex_;
431439
streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base;
432440
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
441+
streamInfo.avMediaType = mediaType;
433442

434-
if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) {
435-
throw std::invalid_argument(
436-
"Stream with index " + std::to_string(activeStreamIndex_) +
437-
" is not a video stream.");
438-
}
439-
440-
if (videoStreamOptions.device.type() == torch::kCUDA) {
443+
// This should never happen, checking just to be safe.
444+
TORCH_CHECK(
445+
streamInfo.stream->codecpar->codec_type == mediaType,
446+
"FFmpeg found stream with index ",
447+
activeStreamIndex_,
448+
" which is of the wrong media type.");
449+
450+
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
451+
// addStream() which is supposed to be generic
452+
if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) {
441453
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
442-
findCudaCodec(
443-
videoStreamOptions.device, streamInfo.stream->codecpar->codec_id)
454+
findCudaCodec(device, streamInfo.stream->codecpar->codec_id)
444455
.value_or(avCodec));
445456
}
446457

447-
StreamMetadata& streamMetadata =
448-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
449-
if (seekMode_ == SeekMode::approximate &&
450-
!streamMetadata.averageFps.has_value()) {
451-
throw std::runtime_error(
452-
"Seek mode is approximate, but stream " +
453-
std::to_string(activeStreamIndex_) +
454-
" does not have an average fps in its metadata.");
455-
}
456-
457458
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
458459
TORCH_CHECK(codecContext != nullptr);
459-
codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
460460
streamInfo.codecContext.reset(codecContext);
461461

462462
int retVal = avcodec_parameters_to_context(
463463
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
464464
TORCH_CHECK_EQ(retVal, AVSUCCESS);
465465

466-
if (videoStreamOptions.device.type() == torch::kCPU) {
467-
// No more initialization needed for CPU.
468-
} else if (videoStreamOptions.device.type() == torch::kCUDA) {
469-
initializeContextOnCuda(videoStreamOptions.device, codecContext);
470-
} else {
471-
TORCH_CHECK(
472-
false, "Invalid device type: " + videoStreamOptions.device.str());
466+
streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
467+
468+
// TODO_CODE_QUALITY same as above.
469+
if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) {
470+
initializeContextOnCuda(device, codecContext);
473471
}
474-
streamInfo.videoStreamOptions = videoStreamOptions;
475472

476473
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
477474
if (retVal < AVSUCCESS) {
478475
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
479476
}
480477

481478
codecContext->time_base = streamInfo.stream->time_base;
482-
483-
containerMetadata_.allStreamMetadata[activeStreamIndex_].width =
484-
codecContext->width;
485-
containerMetadata_.allStreamMetadata[activeStreamIndex_].height =
486-
codecContext->height;
487-
auto codedId = codecContext->codec_id;
488479
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
489-
std::string(avcodec_get_name(codedId));
480+
std::string(avcodec_get_name(codecContext->codec_id));
490481

491482
// We will only need packets from the active stream, so we tell FFmpeg to
492483
// discard packets from the other streams. Note that av_read_frame() may still
@@ -497,6 +488,38 @@ void VideoDecoder::addVideoStream(
497488
formatContext_->streams[i]->discard = AVDISCARD_ALL;
498489
}
499490
}
491+
}
492+
493+
void VideoDecoder::addVideoStream(
494+
int streamIndex,
495+
const VideoStreamOptions& videoStreamOptions) {
496+
TORCH_CHECK(
497+
videoStreamOptions.device.type() == torch::kCPU ||
498+
videoStreamOptions.device.type() == torch::kCUDA,
499+
"Invalid device type: " + videoStreamOptions.device.str());
500+
501+
addStream(
502+
streamIndex,
503+
AVMEDIA_TYPE_VIDEO,
504+
videoStreamOptions.device,
505+
videoStreamOptions.ffmpegThreadCount);
506+
507+
auto& streamMetadata =
508+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
509+
510+
if (seekMode_ == SeekMode::approximate &&
511+
!streamMetadata.averageFps.has_value()) {
512+
throw std::runtime_error(
513+
"Seek mode is approximate, but stream " +
514+
std::to_string(activeStreamIndex_) +
515+
" does not have an average fps in its metadata.");
516+
}
517+
518+
auto& streamInfo = streamInfos_[activeStreamIndex_];
519+
streamInfo.videoStreamOptions = videoStreamOptions;
520+
521+
streamMetadata.width = streamInfo.codecContext->width;
522+
streamMetadata.height = streamInfo.codecContext->height;
500523

501524
// By default, we want to use swscale for color conversion because it is
502525
// faster. However, it has width requirements, so we may need to fall back
@@ -505,7 +528,7 @@ void VideoDecoder::addVideoStream(
505528
// swscale's width requirements to be violated. We don't expose the ability to
506529
// choose color conversion library publicly; we only use this ability
507530
// internally.
508-
int width = videoStreamOptions.width.value_or(codecContext->width);
531+
int width = videoStreamOptions.width.value_or(streamInfo.codecContext->width);
509532

510533
// swscale requires widths to be multiples of 32:
511534
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
@@ -518,6 +541,21 @@ void VideoDecoder::addVideoStream(
518541
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
519542
}
520543

544+
void VideoDecoder::addAudioStream(int streamIndex) {
545+
TORCH_CHECK(
546+
seekMode_ == SeekMode::approximate,
547+
"seek_mode must be 'approximate' for audio streams.");
548+
549+
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
550+
551+
auto& streamInfo = streamInfos_[activeStreamIndex_];
552+
auto& streamMetadata =
553+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
554+
streamMetadata.sampleRate =
555+
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
556+
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
557+
}
558+
521559
// --------------------------------------------------------------------------
522560
// HIGH-LEVEL DECODING ENTRY-POINTS
523561
// --------------------------------------------------------------------------
@@ -546,7 +584,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
546584
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
547585
int64_t frameIndex,
548586
std::optional<torch::Tensor> preAllocatedOutputTensor) {
549-
validateActiveStream();
587+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
550588

551589
const auto& streamInfo = streamInfos_[activeStreamIndex_];
552590
const auto& streamMetadata =
@@ -560,7 +598,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
560598

561599
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
562600
const std::vector<int64_t>& frameIndices) {
563-
validateActiveStream();
601+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
564602

565603
auto indicesAreSorted =
566604
std::is_sorted(frameIndices.begin(), frameIndices.end());
@@ -619,7 +657,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
619657

620658
VideoDecoder::FrameBatchOutput
621659
VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
622-
validateActiveStream();
660+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
623661

624662
const auto& streamMetadata =
625663
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -690,7 +728,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
690728

691729
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
692730
const std::vector<double>& timestamps) {
693-
validateActiveStream();
731+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
694732

695733
const auto& streamMetadata =
696734
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -721,7 +759,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
721759
VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
722760
double startSeconds,
723761
double stopSeconds) {
724-
validateActiveStream();
762+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
725763

726764
const auto& streamMetadata =
727765
containerMetadata_.allStreamMetadata[activeStreamIndex_];
@@ -860,7 +898,7 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
860898
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
861899
// the comment of canWeAvoidSeeking() for details.
862900
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
863-
validateActiveStream();
901+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
864902
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
865903

866904
int64_t desiredPts =
@@ -907,7 +945,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
907945

908946
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
909947
std::function<bool(AVFrame*)> filterFunction) {
910-
validateActiveStream();
948+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
911949

912950
resetDecodeStats();
913951

@@ -1587,7 +1625,8 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
15871625
// VALIDATION UTILS
15881626
// --------------------------------------------------------------------------
15891627

1590-
void VideoDecoder::validateActiveStream() {
1628+
void VideoDecoder::validateActiveStream(
1629+
std::optional<AVMediaType> avMediaType) {
15911630
auto errorMsg =
15921631
"Provided stream index=" + std::to_string(activeStreamIndex_) +
15931632
" was not previously added.";
@@ -1601,6 +1640,14 @@ void VideoDecoder::validateActiveStream() {
16011640
"Invalid stream index=" + std::to_string(activeStreamIndex_) +
16021641
"; valid indices are in the range [0, " +
16031642
std::to_string(allStreamMetadataSize) + ").");
1643+
1644+
if (avMediaType.has_value()) {
1645+
TORCH_CHECK(
1646+
streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value(),
1647+
"The method you called isn't supported. ",
1648+
"If you're seeing this error, you are probably trying to call an ",
1649+
"unsupported method on an audio stream.");
1650+
}
16041651
}
16051652

16061653
void VideoDecoder::validateScannedAllStreams(const std::string& msg) {
@@ -1648,7 +1695,7 @@ void VideoDecoder::resetDecodeStats() {
16481695
}
16491696

16501697
double VideoDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
1651-
validateActiveStream();
1698+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
16521699
validateScannedAllStreams("getPtsSecondsForFrame");
16531700

16541701
const auto& streamInfo = streamInfos_[activeStreamIndex_];

0 commit comments

Comments
 (0)