Skip to content

Commit 330b4d5

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into file_like
2 parents 5134aff + 5713507 commit 330b4d5

File tree

8 files changed

+199
-25
lines changed

8 files changed

+199
-25
lines changed

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,10 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3})
142142
${CMAKE_CURRENT_SOURCE_DIR}/fetch_and_expose_non_gpl_ffmpeg_libs.cmake
143143
)
144144

145-
make_torchcodec_libraries(4 ffmpeg4 $ffmpeg4_INCLUDE_DIRS)
146145
make_torchcodec_libraries(7 ffmpeg7 $ffmpeg7_INCLUDE_DIRs)
147146
make_torchcodec_libraries(6 ffmpeg6 $ffmpeg6_INCLUDE_DIRS)
147+
make_torchcodec_libraries(4 ffmpeg4 $ffmpeg4_INCLUDE_DIRS)
148148
make_torchcodec_libraries(5 ffmpeg5 $ffmpeg5_INCLUDE_DIRS)
149-
150149
else()
151150
message(
152151
STATUS
@@ -162,6 +161,7 @@ else()
162161
libavformat
163162
libavcodec
164163
libavutil
164+
libswresample
165165
libswscale
166166
)
167167

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63-
int getNumChannels(const AVFrame* avFrame) {
63+
int getNumChannels(const UniqueAVFrame& avFrame) {
6464
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6565
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
6666
return avFrame->ch_layout.nb_channels;
@@ -78,4 +78,55 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
7878
#endif
7979
}
8080

81+
// Copies the channel layout of srcAVFrame onto dstAVFrame.
//
// The version guard matches the one used by getNumChannels(): the
// AVChannelLayout-based `ch_layout` field is only used when that guard says it
// is available, otherwise the legacy `channel_layout` bitmask is copied.
void setChannelLayout(
    UniqueAVFrame& dstAVFrame,
    const UniqueAVFrame& srcAVFrame) {
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
    (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
  // AVChannelLayout can reference separately allocated data (custom channel
  // maps), so a plain struct assignment would leave both frames pointing at
  // the same allocation. av_channel_layout_copy() performs a proper copy.
  auto status = av_channel_layout_copy(
      &dstAVFrame->ch_layout, &srcAVFrame->ch_layout);
  TORCH_CHECK(
      status == AVSUCCESS,
      "Couldn't copy channel layout to avFrame: ",
      getFFMPEGErrorStringFromErrorCode(status));
#else
  // Legacy API: the layout is a plain integer bitmask, assignment is enough.
  dstAVFrame->channel_layout = srcAVFrame->channel_layout;
#endif
}
90+
91+
// Allocates (but does not initialize) a SwrContext that converts audio from
// sourceSampleFormat to desiredSampleFormat. The channel layout is taken from
// avCodecContext and the same layout and sampleRate are used for input and
// output, so only the sample format changes.
//
// Returns a raw, non-null pointer; the caller is responsible for calling
// swr_init() on it and for eventually freeing it. Raises (via TORCH_CHECK) if
// the context cannot be created.
SwrContext* allocateSwrContext(
    UniqueAVCodecContext& avCodecContext,
    int sampleRate,
    AVSampleFormat sourceSampleFormat,
    AVSampleFormat desiredSampleFormat) {
  SwrContext* swrContext = nullptr;
  // Same guard as getNumChannels(): only use the AVChannelLayout-based API
  // when `ch_layout` is known to be available.
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
    (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
  AVChannelLayout layout = avCodecContext->ch_layout;
  auto status = swr_alloc_set_opts2(
      &swrContext,
      &layout, // output layout
      desiredSampleFormat,
      sampleRate,
      &layout, // input layout
      sourceSampleFormat,
      sampleRate,
      0,
      nullptr);

  TORCH_CHECK(
      status == AVSUCCESS,
      "Couldn't create SwrContext: ",
      getFFMPEGErrorStringFromErrorCode(status));
#else
  int64_t layout = static_cast<int64_t>(avCodecContext->channel_layout);
  swrContext = swr_alloc_set_opts(
      nullptr,
      layout, // output layout
      desiredSampleFormat,
      sampleRate,
      layout, // input layout
      sourceSampleFormat,
      sampleRate,
      0,
      nullptr);
#endif

  TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext");
  return swrContext;
}
131+
81132
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ extern "C" {
2222
#include <libavutil/opt.h>
2323
#include <libavutil/pixfmt.h>
2424
#include <libavutil/version.h>
25+
#include <libswresample/swresample.h>
2526
#include <libswscale/swscale.h>
2627
}
2728

@@ -67,6 +68,8 @@ using UniqueAVIOContext = std::
6768
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
6869
using UniqueSwsContext =
6970
std::unique_ptr<SwsContext, Deleter<SwsContext, void, sws_freeContext>>;
71+
using UniqueSwrContext =
72+
std::unique_ptr<SwrContext, Deleterp<SwrContext, void, swr_free>>;
7073

7174
// These 2 classes share the same underlying AVPacket object. They are meant to
7275
// be used in tandem, like so:
@@ -139,9 +142,18 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139142
int64_t getDuration(const UniqueAVFrame& frame);
140143
int64_t getDuration(const AVFrame* frame);
141144

142-
int getNumChannels(const AVFrame* avFrame);
145+
int getNumChannels(const UniqueAVFrame& avFrame);
143146
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
144147

148+
void setChannelLayout(
149+
UniqueAVFrame& dstAVFrame,
150+
const UniqueAVFrame& srcAVFrame);
151+
SwrContext* allocateSwrContext(
152+
UniqueAVCodecContext& avCodecContext,
153+
int sampleRate,
154+
AVSampleFormat sourceSampleFormat,
155+
AVSampleFormat desiredSampleFormat);
156+
145157
// Returns true if sws_scale can handle unaligned data.
146158
bool canSwsScaleHandleUnalignedData();
147159

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 99 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ extern "C" {
2323
#include <libavutil/imgutils.h>
2424
#include <libavutil/log.h>
2525
#include <libavutil/pixdesc.h>
26+
#include <libswresample/swresample.h>
2627
#include <libswscale/swscale.h>
2728
}
2829

@@ -557,6 +558,12 @@ void VideoDecoder::addAudioStream(int streamIndex) {
557558
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
558559
streamMetadata.numChannels =
559560
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
561+
562+
// FFmpeg docs say that the decoder will try to decode natively in this
563+
// format, if it can. Docs don't say what the decoder does when it doesn't
564+
// support that format, but it looks like it does nothing, so this probably
565+
// doesn't hurt.
566+
streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP;
560567
}
561568

562569
// --------------------------------------------------------------------------
@@ -1348,37 +1355,89 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13481355
!preAllocatedOutputTensor.has_value(),
13491356
"pre-allocated audio tensor not supported yet.");
13501357

1351-
const AVFrame* avFrame = avFrameStream.avFrame.get();
1358+
AVSampleFormat sourceSampleFormat =
1359+
static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
1360+
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1361+
1362+
UniqueAVFrame convertedAVFrame;
1363+
if (sourceSampleFormat != desiredSampleFormat) {
1364+
convertedAVFrame = convertAudioAVFrameSampleFormat(
1365+
avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat);
1366+
}
1367+
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1368+
? convertedAVFrame
1369+
: avFrameStream.avFrame;
1370+
1371+
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1372+
TORCH_CHECK(
1373+
format == desiredSampleFormat,
1374+
"Something went wrong, the frame didn't get converted to the desired format. ",
1375+
"Desired format = ",
1376+
av_get_sample_fmt_name(desiredSampleFormat),
1377+
"source format = ",
1378+
av_get_sample_fmt_name(format));
13521379

13531380
auto numSamples = avFrame->nb_samples; // per channel
13541381
auto numChannels = getNumChannels(avFrame);
13551382
torch::Tensor outputData =
13561383
torch::empty({numChannels, numSamples}, torch::kFloat32);
13571384

1358-
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1359-
// TODO-AUDIO Implement all formats.
1360-
switch (format) {
1361-
case AV_SAMPLE_FMT_FLTP: {
1362-
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1363-
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1364-
for (auto channel = 0; channel < numChannels;
1365-
++channel, outputChannelData += numBytesPerChannel) {
1366-
memcpy(
1367-
outputChannelData,
1368-
avFrame->extended_data[channel],
1369-
numBytesPerChannel);
1370-
}
1371-
break;
1372-
}
1373-
default:
1374-
TORCH_CHECK(
1375-
false,
1376-
"Unsupported audio format (yet!): ",
1377-
av_get_sample_fmt_name(format));
1385+
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1386+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1387+
for (auto channel = 0; channel < numChannels;
1388+
++channel, outputChannelData += numBytesPerChannel) {
1389+
memcpy(
1390+
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
13781391
}
13791392
frameOutput.data = outputData;
13801393
}
13811394

1395+
// Converts the samples of avFrame from sourceSampleFormat to
// desiredSampleFormat and returns them in a newly allocated frame. The
// channel layout, sample rate and number of samples are carried over from
// avFrame unchanged.
//
// The SwrContext of the active stream is created on first use and re-used for
// subsequent calls. Raises (via TORCH_CHECK) if a frame or its buffers cannot
// be allocated, or if swr_convert() does not report a positive sample count.
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
    const UniqueAVFrame& avFrame,
    AVSampleFormat sourceSampleFormat,
    AVSampleFormat desiredSampleFormat) {
  auto& streamInfo = streamInfos_[activeStreamIndex_];
  const auto& streamMetadata =
      containerMetadata_.allStreamMetadata[activeStreamIndex_];
  int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());

  if (!streamInfo.swrContext) {
    createSwrContext(
        streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
  }

  UniqueAVFrame convertedAVFrame(av_frame_alloc());
  TORCH_CHECK(
      convertedAVFrame,
      "Could not allocate frame for sample format conversion.");

  // These fields must be set before av_frame_get_buffer() so that it knows
  // how much memory to allocate.
  setChannelLayout(convertedAVFrame, avFrame);
  convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
  convertedAVFrame->sample_rate = avFrame->sample_rate;
  convertedAVFrame->nb_samples = avFrame->nb_samples;

  auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
  TORCH_CHECK(
      status == AVSUCCESS,
      "Could not allocate frame buffers for sample format conversion: ",
      getFFMPEGErrorStringFromErrorCode(status));

  // Use extended_data rather than data: for planar audio, data only holds the
  // first AV_NUM_DATA_POINTERS channel planes while extended_data always
  // covers every channel. This also matches how the converted frame is read
  // back when it is copied into the output tensor.
  auto numSampleConverted = swr_convert(
      streamInfo.swrContext.get(),
      convertedAVFrame->extended_data,
      convertedAVFrame->nb_samples,
      const_cast<const uint8_t**>(avFrame->extended_data),
      avFrame->nb_samples);
  TORCH_CHECK(
      numSampleConverted > 0,
      "Error in swr_convert: ",
      getFFMPEGErrorStringFromErrorCode(numSampleConverted));

  return convertedAVFrame;
}
1440+
13821441
// --------------------------------------------------------------------------
13831442
// OUTPUT ALLOCATION AND SHAPE CONVERSION
13841443
// --------------------------------------------------------------------------
@@ -1612,6 +1671,25 @@ void VideoDecoder::createSwsContext(
16121671
streamInfo.swsContext.reset(swsContext);
16131672
}
16141673

1674+
// Allocates and initializes the SwrContext used to convert decoded audio
// frames from sourceSampleFormat to desiredSampleFormat, and stores it in
// streamInfo.swrContext.
//
// Raises (via TORCH_CHECK) if the context cannot be allocated or initialized.
// streamInfo.swrContext is only updated once initialization has succeeded.
void VideoDecoder::createSwrContext(
    StreamInfo& streamInfo,
    int sampleRate,
    AVSampleFormat sourceSampleFormat,
    AVSampleFormat desiredSampleFormat) {
  // Take ownership immediately: TORCH_CHECK throws on failure, and a raw
  // pointer would leak the freshly allocated context in that case.
  UniqueSwrContext swrContext(allocateSwrContext(
      streamInfo.codecContext,
      sampleRate,
      sourceSampleFormat,
      desiredSampleFormat));

  auto status = swr_init(swrContext.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Couldn't initialize SwrContext: ",
      getFFMPEGErrorStringFromErrorCode(status));

  // Hand the initialized context over to the stream; any previously stored
  // context is released by reset().
  streamInfo.swrContext.reset(swrContext.release());
}
1692+
16151693
// --------------------------------------------------------------------------
16161694
// PTS <-> INDEX CONVERSIONS
16171695
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ class VideoDecoder {
357357
FilterGraphContext filterGraphContext;
358358
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
359359
UniqueSwsContext swsContext;
360+
UniqueSwrContext swrContext;
360361

361362
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
362363
// be created before decoding a new frame.
@@ -404,6 +405,11 @@ class VideoDecoder {
404405
const AVFrame* avFrame,
405406
torch::Tensor& outputTensor);
406407

408+
UniqueAVFrame convertAudioAVFrameSampleFormat(
409+
const UniqueAVFrame& avFrame,
410+
AVSampleFormat sourceSampleFormat,
411+
AVSampleFormat desiredSampleFormat);
412+
407413
// --------------------------------------------------------------------------
408414
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
409415
// --------------------------------------------------------------------------
@@ -418,6 +424,12 @@ class VideoDecoder {
418424
const DecodedFrameContext& frameContext,
419425
const enum AVColorSpace colorspace);
420426

427+
void createSwrContext(
428+
StreamInfo& streamInfo,
429+
int sampleRate,
430+
AVSampleFormat sourceSampleFormat,
431+
AVSampleFormat desiredSampleFormat);
432+
421433
// --------------------------------------------------------------------------
422434
// PTS <-> INDEX CONVERSIONS
423435
// --------------------------------------------------------------------------

test/decoders/test_decoders.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,3 +1070,21 @@ def test_frame_start_is_not_zero(self):
10701070

10711071
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
10721072
torch.testing.assert_close(samples.data, reference_frames)
1073+
1074+
def test_single_channel(self):
    """Decoding the mono asset yields a tensor whose first dimension is 1."""
    mono_asset = SINE_MONO_S32
    audio_decoder = AudioDecoder(mono_asset.path)

    # Decode the first two seconds of audio.
    decoded = audio_decoder.get_samples_played_in_range(
        start_seconds=0, stop_seconds=2
    )
    assert decoded.data.shape[0] == mono_asset.num_channels == 1
1080+
1081+
def test_format_conversion(self):
    """An s32 source is decoded to float32 samples matching the reference."""
    s32_asset = SINE_MONO_S32
    audio_decoder = AudioDecoder(s32_asset.path)
    # The file itself must report the s32 sample format.
    assert audio_decoder.metadata.sample_format == s32_asset.sample_format == "s32"

    # Decode everything from the start of the stream.
    decoded = audio_decoder.get_samples_played_in_range(start_seconds=0)
    assert decoded.data.dtype == torch.float32

    expected = s32_asset.get_frame_data_by_range(
        start=0, stop=s32_asset.num_frames
    )
    torch.testing.assert_close(decoded.data, expected)
266 KB
Binary file not shown.

test/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ def sample_format(self) -> str:
444444
},
445445
)
446446

447+
# Note that the file itself is s32 sample format, but the reference frames are
448+
# stored as fltp. We can add the s32 original reference frames once we support
449+
# decoding to non-fltp format, but for now we don't need to.
447450
SINE_MONO_S32 = TestAudio(
448451
filename="sine_mono_s32.wav",
449452
default_stream_index=0,

0 commit comments

Comments
 (0)