Skip to content

Commit f6ecd32

Browse files
committed
WIP
1 parent 979e72a commit f6ecd32

File tree

3 files changed

+122
-91
lines changed

3 files changed

+122
-91
lines changed

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ extern "C" {
2222
#include <libavutil/opt.h>
2323
#include <libavutil/pixfmt.h>
2424
#include <libavutil/version.h>
25+
#include <libswresample/swresample.h>
2526
#include <libswscale/swscale.h>
2627
}
2728

@@ -67,6 +68,8 @@ using UniqueAVIOContext = std::
6768
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
6869
using UniqueSwsContext =
6970
std::unique_ptr<SwsContext, Deleter<SwsContext, void, sws_freeContext>>;
71+
using UniqueSwrContext =
72+
std::unique_ptr<SwrContext, Deleterp<SwrContext, void, swr_free>>;
7073

7174
// These 2 classes share the same underlying AVPacket object. They are meant to
7275
// be used in tandem, like so:

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 107 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,108 +1346,91 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13461346
!preAllocatedOutputTensor.has_value(),
13471347
"pre-allocated audio tensor not supported yet.");
13481348

1349-
const AVFrame* avFrame = avFrameStream.avFrame.get();
1349+
const UniqueAVFrame& avFrame = avFrameStream.avFrame;
13501350

13511351
AVSampleFormat sourceSampleFormat =
13521352
static_cast<AVSampleFormat>(avFrame->format);
13531353
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1354-
1355-
AVFrame* output_frame = nullptr;
1356-
SwrContext* swr_ctx = NULL; // TODO RAII
1357-
if (sourceSampleFormat != desiredSampleFormat) {
1358-
1359-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
1360-
const auto& streamMetadata =
1361-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1362-
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());
1363-
1364-
AVChannelLayout layout = streamInfo.codecContext->ch_layout;
1365-
1366-
int status = swr_alloc_set_opts2(
1367-
&swr_ctx,
1368-
&layout,
1369-
desiredSampleFormat,
1370-
sampleRate,
1371-
&layout,
1372-
sourceSampleFormat,
1373-
sampleRate,
1374-
0,
1375-
NULL);
1376-
1377-
TORCH_CHECK(status == 0, "IS NULL");
1378-
1379-
if (swr_init(swr_ctx) < 0) {
1380-
swr_free(&swr_ctx);
1381-
TORCH_CHECK(false, "Failed to initialize the resampling context\n");
1382-
}
1383-
1384-
// Allocate output frame
1385-
output_frame = av_frame_alloc();
1386-
if (!output_frame) {
1387-
swr_free(&swr_ctx);
1388-
TORCH_CHECK(false, "Could not allocate output frame\n");
1389-
}
1390-
output_frame->ch_layout = layout;
1391-
output_frame->sample_rate = sampleRate;
1392-
output_frame->format = desiredSampleFormat;
1393-
1394-
output_frame->nb_samples = av_rescale_rnd(
1395-
swr_get_delay(swr_ctx, sampleRate) + avFrame->nb_samples,
1396-
sampleRate,
1397-
sampleRate,
1398-
AV_ROUND_UP);
1399-
1400-
if (av_frame_get_buffer(output_frame, 0) < 0) {
1401-
av_frame_free(&output_frame);
1402-
swr_free(&swr_ctx);
1403-
TORCH_CHECK(false, "Could not allocate output frame samples");
1404-
}
1405-
1406-
int ret = swr_convert(
1407-
swr_ctx,
1408-
output_frame->data,
1409-
output_frame->nb_samples,
1410-
(const uint8_t**)avFrame->data,
1411-
avFrame->nb_samples);
1412-
if (ret < 0) {
1413-
av_frame_free(&output_frame);
1414-
swr_free(&swr_ctx);
1415-
TORCH_CHECK(false, "Error while converting\n");
1416-
}
1417-
1418-
avFrame = output_frame; // lmao
1354+
AVFrame* rawAVFrame = nullptr;
1355+
UniqueAVFrame convertedAVFrame;
1356+
if (sourceSampleFormat == desiredSampleFormat) {
1357+
rawAVFrame = avFrame.get();
1358+
} else {
1359+
convertedAVFrame = convertAudioAVFrameSampleFormat(
1360+
avFrame, sourceSampleFormat, desiredSampleFormat);
1361+
rawAVFrame = convertedAVFrame.get();
14191362
}
14201363

1421-
auto numSamples = avFrame->nb_samples; // per channel
1422-
auto numChannels = getNumChannels(avFrame);
1364+
AVSampleFormat format = static_cast<AVSampleFormat>(rawAVFrame->format);
1365+
TORCH_CHECK(
1366+
format == desiredSampleFormat,
1367+
"Something went wrong, the frame didn't get converted to the desired format. ",
1368+
"Desired format = ",
1369+
av_get_sample_fmt_name(desiredSampleFormat),
1370+
"source format = ",
1371+
av_get_sample_fmt_name(format));
1372+
1373+
auto numSamples = rawAVFrame->nb_samples; // per channel
1374+
auto numChannels = getNumChannels(rawAVFrame);
14231375
torch::Tensor outputData =
14241376
torch::empty({numChannels, numSamples}, torch::kFloat32);
14251377

1426-
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1427-
// TODO-AUDIO Implement all formats.
1428-
switch (format) {
1429-
case AV_SAMPLE_FMT_FLTP: {
1430-
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1431-
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1432-
for (auto channel = 0; channel < numChannels;
1433-
++channel, outputChannelData += numBytesPerChannel) {
1434-
memcpy(
1435-
outputChannelData,
1436-
avFrame->extended_data[channel],
1437-
numBytesPerChannel);
1438-
}
1439-
break;
1440-
}
1441-
default:
1442-
TORCH_CHECK(
1443-
false,
1444-
"Unsupported audio format (yet!): ",
1445-
av_get_sample_fmt_name(format));
1378+
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1379+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1380+
for (auto channel = 0; channel < numChannels;
1381+
++channel, outputChannelData += numBytesPerChannel) {
1382+
memcpy(
1383+
outputChannelData,
1384+
rawAVFrame->extended_data[channel],
1385+
numBytesPerChannel);
14461386
}
14471387
frameOutput.data = outputData;
1448-
// TODO
1449-
av_frame_free(&output_frame);
1450-
swr_free(&swr_ctx);
1388+
}
1389+
1390+
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
1391+
const UniqueAVFrame& avFrame,
1392+
AVSampleFormat sourceSampleFormat,
1393+
AVSampleFormat desiredSampleFormat
1394+
1395+
) {
1396+
auto& streamInfo = streamInfos_[activeStreamIndex_];
1397+
const auto& streamMetadata =
1398+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1399+
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());
1400+
1401+
if (!streamInfo.swrContext) {
1402+
createSwrContext(
1403+
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1404+
}
1405+
1406+
UniqueAVFrame convertedAVFrame(av_frame_alloc());
1407+
TORCH_CHECK(
1408+
convertedAVFrame,
1409+
"Could not allocate frame for sample format conversion.");
1410+
1411+
convertedAVFrame->ch_layout = avFrame->ch_layout;
1412+
convertedAVFrame->sample_rate = avFrame->sample_rate;
1413+
convertedAVFrame->nb_samples = avFrame->nb_samples;
1414+
convertedAVFrame->format = desiredSampleFormat;
1415+
1416+
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
1417+
TORCH_CHECK(
1418+
status == AVSUCCESS,
1419+
"Could not allocate frame buffers for sample format conversion: ",
1420+
getFFMPEGErrorStringFromErrorCode(status));
1421+
1422+
auto numSampleConverted = swr_convert(
1423+
streamInfo.swrContext.get(),
1424+
convertedAVFrame->data,
1425+
convertedAVFrame->nb_samples,
1426+
(const uint8_t**)avFrame->data,
1427+
avFrame->nb_samples);
1428+
TORCH_CHECK(
1429+
numSampleConverted > 0,
1430+
"Error in swr_convert: ",
1431+
getFFMPEGErrorStringFromErrorCode(numSampleConverted));
1432+
1433+
return convertedAVFrame;
14511434
}
14521435

14531436
// --------------------------------------------------------------------------
@@ -1683,6 +1666,39 @@ void VideoDecoder::createSwsContext(
16831666
streamInfo.swsContext.reset(swsContext);
16841667
}
16851668

1669+
void VideoDecoder::createSwrContext(
1670+
StreamInfo& streamInfo,
1671+
int sampleRate,
1672+
AVSampleFormat sourceSampleFormat,
1673+
AVSampleFormat desiredSampleFormat) {
1674+
SwrContext* swrContext = NULL;
1675+
1676+
AVChannelLayout layout = streamInfo.codecContext->ch_layout;
1677+
1678+
auto status = swr_alloc_set_opts2(
1679+
&swrContext,
1680+
&layout,
1681+
desiredSampleFormat,
1682+
sampleRate,
1683+
&layout,
1684+
sourceSampleFormat,
1685+
sampleRate,
1686+
0,
1687+
NULL);
1688+
1689+
TORCH_CHECK(
1690+
status == AVSUCCESS,
1691+
"Couldn't create SwrContext: ",
1692+
getFFMPEGErrorStringFromErrorCode(status));
1693+
1694+
status = swr_init(swrContext);
1695+
TORCH_CHECK(
1696+
status == AVSUCCESS,
1697+
"Couldn't initialize SwrContext: ",
1698+
getFFMPEGErrorStringFromErrorCode(status));
1699+
streamInfo.swrContext.reset(swrContext);
1700+
}
1701+
16861702
// --------------------------------------------------------------------------
16871703
// PTS <-> INDEX CONVERSIONS
16881704
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ class VideoDecoder {
354354
FilterGraphContext filterGraphContext;
355355
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
356356
UniqueSwsContext swsContext;
357+
UniqueSwrContext swrContext;
357358

358359
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
359360
// be created before decoding a new frame.
@@ -400,6 +401,11 @@ class VideoDecoder {
400401
const AVFrame* avFrame,
401402
torch::Tensor& outputTensor);
402403

404+
UniqueAVFrame convertAudioAVFrameSampleFormat(
405+
const UniqueAVFrame& avFrame,
406+
AVSampleFormat sourceSampleFormat,
407+
AVSampleFormat desiredSampleFormat);
408+
403409
// --------------------------------------------------------------------------
404410
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
405411
// --------------------------------------------------------------------------
@@ -414,6 +420,12 @@ class VideoDecoder {
414420
const DecodedFrameContext& frameContext,
415421
const enum AVColorSpace colorspace);
416422

423+
void createSwrContext(
424+
StreamInfo& streamInfo,
425+
int sampleRate,
426+
AVSampleFormat sourceSampleFormat,
427+
AVSampleFormat desiredSampleFormat);
428+
417429
// --------------------------------------------------------------------------
418430
// PTS <-> INDEX CONVERSIONS
419431
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)