Skip to content

Commit 9abd343

Browse files
authored
Merge branch 'main' into remove-stream-video-options
2 parents 23b3cf7 + 611421e commit 9abd343

14 files changed

+1110
-63
lines changed

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ def __init__(
2525
source: Union[str, Path, bytes, Tensor],
2626
*,
2727
stream_index: Optional[int] = None,
28+
sample_rate: Optional[int] = None,
2829
):
2930
self._decoder = create_decoder(source=source, seek_mode="approximate")
3031

31-
core.add_audio_stream(self._decoder, stream_index=stream_index)
32+
core.add_audio_stream(
33+
self._decoder, stream_index=stream_index, sample_rate=sample_rate
34+
)
3235

3336
(
3437
self.metadata,
@@ -39,6 +42,9 @@ def __init__(
3942
decoder=self._decoder, stream_index=stream_index, media_type="audio"
4043
)
4144
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
45+
self._desired_sample_rate = (
46+
sample_rate if sample_rate is not None else self.metadata.sample_rate
47+
)
4248

4349
def get_samples_played_in_range(
4450
self, start_seconds: float, stop_seconds: Optional[float] = None
@@ -75,11 +81,7 @@ def get_samples_played_in_range(
7581
# So we do some basic math to figure out the position of the view that
7682
# we'll return.
7783

78-
# TODO: sample_rate is either the original one from metadata, or the
79-
# user-specified one (NIY)
80-
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
81-
sample_rate = self.metadata.sample_rate
82-
84+
sample_rate = self._desired_sample_rate
8385
# TODO: metadata's sample_rate should probably not be Optional
8486
assert sample_rate is not None # mypy.
8587

@@ -94,7 +96,7 @@ def get_samples_played_in_range(
9496
output_pts_seconds = first_pts
9597

9698
num_samples = frames.shape[1]
97-
last_pts = first_pts + num_samples / self.metadata.sample_rate
99+
last_pts = first_pts + num_samples / sample_rate
98100
if stop_seconds is not None and stop_seconds < last_pts:
99101
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
100102
else:

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,21 @@ void setChannelLayout(
8686

8787
SwrContext* allocateSwrContext(
8888
UniqueAVCodecContext& avCodecContext,
89-
int sampleRate,
9089
AVSampleFormat sourceSampleFormat,
91-
AVSampleFormat desiredSampleFormat) {
90+
AVSampleFormat desiredSampleFormat,
91+
int sourceSampleRate,
92+
int desiredSampleRate) {
9293
SwrContext* swrContext = nullptr;
9394
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
9495
AVChannelLayout layout = avCodecContext->ch_layout;
9596
auto status = swr_alloc_set_opts2(
9697
&swrContext,
9798
&layout,
9899
desiredSampleFormat,
99-
sampleRate,
100+
desiredSampleRate,
100101
&layout,
101102
sourceSampleFormat,
102-
sampleRate,
103+
sourceSampleRate,
103104
0,
104105
nullptr);
105106

@@ -113,10 +114,10 @@ SwrContext* allocateSwrContext(
113114
nullptr,
114115
layout,
115116
desiredSampleFormat,
116-
sampleRate,
117+
desiredSampleRate,
117118
layout,
118119
sourceSampleFormat,
119-
sampleRate,
120+
sourceSampleRate,
120121
0,
121122
nullptr);
122123
#endif

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,10 @@ void setChannelLayout(
149149
const UniqueAVFrame& srcAVFrame);
150150
SwrContext* allocateSwrContext(
151151
UniqueAVCodecContext& avCodecContext,
152-
int sampleRate,
153152
AVSampleFormat sourceSampleFormat,
154-
AVSampleFormat desiredSampleFormat);
153+
AVSampleFormat desiredSampleFormat,
154+
int sourceSampleRate,
155+
int desiredSampleRate);
155156

156157
// Returns true if sws_scale can handle unaligned data.
157158
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 107 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,18 @@ void VideoDecoder::addVideoStream(
512512
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
513513
}
514514

515-
void VideoDecoder::addAudioStream(int streamIndex) {
515+
void VideoDecoder::addAudioStream(
516+
int streamIndex,
517+
const AudioStreamOptions& audioStreamOptions) {
516518
TORCH_CHECK(
517519
seekMode_ == SeekMode::approximate,
518520
"seek_mode must be 'approximate' for audio streams.");
519521

520522
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
521523

522524
auto& streamInfo = streamInfos_[activeStreamIndex_];
525+
streamInfo.audioStreamOptions = audioStreamOptions;
526+
523527
auto& streamMetadata =
524528
containerMetadata_.allStreamMetadata[activeStreamIndex_];
525529
streamMetadata.sampleRate =
@@ -879,6 +883,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
879883
(stopPts <= lastDecodedAvFrameEnd);
880884
}
881885

886+
auto lastSamples = maybeFlushSwrBuffers();
887+
if (lastSamples.has_value()) {
888+
frames.push_back(*lastSamples);
889+
}
890+
882891
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
883892
}
884893

@@ -1132,8 +1141,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
11321141
getDuration(avFrame),
11331142
formatContext_->streams[activeStreamIndex_]->time_base);
11341143
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1135-
convertAudioAVFrameToFrameOutputOnCPU(
1136-
avFrame, frameOutput, preAllocatedOutputTensor);
1144+
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
11371145
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
11381146
convertAVFrameToFrameOutputOnCPU(
11391147
avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1311,24 +1319,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13111319

13121320
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13131321
UniqueAVFrame& srcAVFrame,
1314-
FrameOutput& frameOutput,
1315-
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1316-
TORCH_CHECK(
1317-
!preAllocatedOutputTensor.has_value(),
1318-
"pre-allocated audio tensor not supported yet.");
1319-
1322+
FrameOutput& frameOutput) {
13201323
AVSampleFormat sourceSampleFormat =
13211324
static_cast<AVSampleFormat>(srcAVFrame->format);
13221325
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13231326

1327+
int sourceSampleRate = srcAVFrame->sample_rate;
1328+
int desiredSampleRate =
1329+
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
1330+
sourceSampleRate);
1331+
1332+
bool mustConvert =
1333+
(sourceSampleFormat != desiredSampleFormat ||
1334+
sourceSampleRate != desiredSampleRate);
1335+
13241336
UniqueAVFrame convertedAVFrame;
1325-
if (sourceSampleFormat != desiredSampleFormat) {
1326-
convertedAVFrame = convertAudioAVFrameSampleFormat(
1327-
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
1337+
if (mustConvert) {
1338+
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
1339+
srcAVFrame,
1340+
sourceSampleFormat,
1341+
desiredSampleFormat,
1342+
sourceSampleRate,
1343+
desiredSampleRate);
13281344
}
1329-
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1330-
? convertedAVFrame
1331-
: srcAVFrame;
1345+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
13321346

13331347
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13341348
TORCH_CHECK(
@@ -1351,55 +1365,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13511365
memcpy(
13521366
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
13531367
}
1368+
13541369
frameOutput.data = outputData;
13551370
}
13561371

1357-
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
1358-
const UniqueAVFrame& avFrame,
1372+
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
1373+
const UniqueAVFrame& srcAVFrame,
13591374
AVSampleFormat sourceSampleFormat,
1360-
AVSampleFormat desiredSampleFormat
1361-
1362-
) {
1375+
AVSampleFormat desiredSampleFormat,
1376+
int sourceSampleRate,
1377+
int desiredSampleRate) {
13631378
auto& streamInfo = streamInfos_[activeStreamIndex_];
1364-
const auto& streamMetadata =
1365-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1366-
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());
13671379

13681380
if (!streamInfo.swrContext) {
13691381
createSwrContext(
1370-
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1382+
streamInfo,
1383+
sourceSampleFormat,
1384+
desiredSampleFormat,
1385+
sourceSampleRate,
1386+
desiredSampleRate);
13711387
}
13721388

13731389
UniqueAVFrame convertedAVFrame(av_frame_alloc());
13741390
TORCH_CHECK(
13751391
convertedAVFrame,
13761392
"Could not allocate frame for sample format conversion.");
13771393

1378-
setChannelLayout(convertedAVFrame, avFrame);
1394+
setChannelLayout(convertedAVFrame, srcAVFrame);
13791395
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
1380-
convertedAVFrame->sample_rate = avFrame->sample_rate;
1381-
convertedAVFrame->nb_samples = avFrame->nb_samples;
1396+
convertedAVFrame->sample_rate = desiredSampleRate;
1397+
if (sourceSampleRate != desiredSampleRate) {
1398+
// Note that this is an upper bound on the number of output samples.
1399+
// `swr_convert()` will likely not fill convertedAVFrame with that many
1400+
// samples if sample rate conversion is needed. It will buffer the last few
1401+
// ones because those require future samples. That's also why we reset
1402+
// nb_samples after the call to `swr_convert()`.
1403+
// We could also use `swr_get_out_samples()` to determine the number of
1404+
// output samples, but empirically `av_rescale_rnd()` seems to provide a
1405+
// tighter bound.
1406+
convertedAVFrame->nb_samples = av_rescale_rnd(
1407+
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) +
1408+
srcAVFrame->nb_samples,
1409+
desiredSampleRate,
1410+
sourceSampleRate,
1411+
AV_ROUND_UP);
1412+
} else {
1413+
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
1414+
}
13821415

13831416
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
13841417
TORCH_CHECK(
13851418
status == AVSUCCESS,
13861419
"Could not allocate frame buffers for sample format conversion: ",
13871420
getFFMPEGErrorStringFromErrorCode(status));
13881421

1389-
auto numSampleConverted = swr_convert(
1422+
auto numConvertedSamples = swr_convert(
13901423
streamInfo.swrContext.get(),
13911424
convertedAVFrame->data,
13921425
convertedAVFrame->nb_samples,
1393-
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)),
1394-
avFrame->nb_samples);
1426+
static_cast<const uint8_t**>(
1427+
const_cast<const uint8_t**>(srcAVFrame->data)),
1428+
srcAVFrame->nb_samples);
13951429
TORCH_CHECK(
1396-
numSampleConverted > 0,
1430+
numConvertedSamples > 0,
13971431
"Error in swr_convert: ",
1398-
getFFMPEGErrorStringFromErrorCode(numSampleConverted));
1432+
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));
1433+
1434+
// See comment above about nb_samples
1435+
convertedAVFrame->nb_samples = numConvertedSamples;
13991436

14001437
return convertedAVFrame;
14011438
}
14021439

1440+
std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
1441+
// When sample rate conversion is involved, swresample buffers some of the
1442+
// samples in-between calls to swr_convert (see the libswresample docs).
1443+
// That's because the last few samples in a given frame require future samples
1444+
// from the next frame to be properly converted. This function flushes out the
1445+
// samples that are stored in swresample's buffers.
1446+
auto& streamInfo = streamInfos_[activeStreamIndex_];
1447+
if (!streamInfo.swrContext) {
1448+
return std::nullopt;
1449+
}
1450+
auto numRemainingSamples = // this is an upper bound
1451+
swr_get_out_samples(streamInfo.swrContext.get(), 0);
1452+
1453+
if (numRemainingSamples == 0) {
1454+
return std::nullopt;
1455+
}
1456+
1457+
torch::Tensor lastSamples = torch::empty(
1458+
{getNumChannels(streamInfo.codecContext), numRemainingSamples},
1459+
torch::kFloat32);
1460+
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr());
1461+
1462+
auto actualNumRemainingSamples = swr_convert(
1463+
streamInfo.swrContext.get(),
1464+
&lastSamplesData,
1465+
numRemainingSamples,
1466+
nullptr,
1467+
0);
1468+
return lastSamples.narrow(
1469+
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
1470+
}
1471+
14031472
// --------------------------------------------------------------------------
14041473
// OUTPUT ALLOCATION AND SHAPE CONVERSION
14051474
// --------------------------------------------------------------------------
@@ -1635,14 +1704,16 @@ void VideoDecoder::createSwsContext(
16351704

16361705
void VideoDecoder::createSwrContext(
16371706
StreamInfo& streamInfo,
1638-
int sampleRate,
16391707
AVSampleFormat sourceSampleFormat,
1640-
AVSampleFormat desiredSampleFormat) {
1708+
AVSampleFormat desiredSampleFormat,
1709+
int sourceSampleRate,
1710+
int desiredSampleRate) {
16411711
auto swrContext = allocateSwrContext(
16421712
streamInfo.codecContext,
1643-
sampleRate,
16441713
sourceSampleFormat,
1645-
desiredSampleFormat);
1714+
desiredSampleFormat,
1715+
sourceSampleRate,
1716+
desiredSampleRate);
16461717

16471718
auto status = swr_init(swrContext);
16481719
TORCH_CHECK(

0 commit comments

Comments
 (0)