Skip to content

Commit 1e26576

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into file_like_tutorial
2 parents 055e9ca + 2c137e7 commit 1e26576

File tree

8 files changed

+224
-159
lines changed

8 files changed

+224
-159
lines changed

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
namespace facebook::torchcodec {
1212

1313
namespace {
14+
using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
1415
std::mutex g_interface_mutex;
15-
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;
16+
std::unique_ptr<DeviceInterfaceMap> g_interface_map;
1617

1718
std::string getDeviceType(const std::string& device) {
1819
size_t pos = device.find(':');
@@ -28,11 +29,18 @@ bool registerDeviceInterface(
2829
torch::DeviceType deviceType,
2930
CreateDeviceInterfaceFn createInterface) {
3031
std::scoped_lock lock(g_interface_mutex);
32+
if (!g_interface_map) {
33+
// We delay this initialization until runtime to avoid the Static
34+
// Initialization Order Fiasco:
35+
//
36+
// https://en.cppreference.com/w/cpp/language/siof
37+
g_interface_map = std::make_unique<DeviceInterfaceMap>();
38+
}
3139
TORCH_CHECK(
32-
g_interface_map.find(deviceType) == g_interface_map.end(),
40+
g_interface_map->find(deviceType) == g_interface_map->end(),
3341
"Device interface already registered for ",
3442
deviceType);
35-
g_interface_map.insert({deviceType, createInterface});
43+
g_interface_map->insert({deviceType, createInterface});
3644
return true;
3745
}
3846

@@ -45,14 +53,16 @@ torch::Device createTorchDevice(const std::string device) {
4553
std::scoped_lock lock(g_interface_mutex);
4654
std::string deviceType = getDeviceType(device);
4755
auto deviceInterface = std::find_if(
48-
g_interface_map.begin(),
49-
g_interface_map.end(),
56+
g_interface_map->begin(),
57+
g_interface_map->end(),
5058
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
5159
return device.rfind(
5260
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
5361
});
5462
TORCH_CHECK(
55-
deviceInterface != g_interface_map.end(), "Unsupported device: ", device);
63+
deviceInterface != g_interface_map->end(),
64+
"Unsupported device: ",
65+
device);
5666

5767
return torch::Device(device);
5868
}
@@ -67,11 +77,12 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
6777

6878
std::scoped_lock lock(g_interface_mutex);
6979
TORCH_CHECK(
70-
g_interface_map.find(deviceType) != g_interface_map.end(),
80+
g_interface_map->find(deviceType) != g_interface_map->end(),
7181
"Unsupported device: ",
7282
device);
7383

74-
return std::unique_ptr<DeviceInterface>(g_interface_map[deviceType](device));
84+
return std::unique_ptr<DeviceInterface>(
85+
(*g_interface_map)[deviceType](device));
7586
}
7687

7788
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.cpp

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,44 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
3333
supportedRates.str());
3434
}
3535

36+
static const std::vector<AVSampleFormat> preferredFormatsOrder = {
37+
AV_SAMPLE_FMT_FLTP,
38+
AV_SAMPLE_FMT_FLT,
39+
AV_SAMPLE_FMT_DBLP,
40+
AV_SAMPLE_FMT_DBL,
41+
AV_SAMPLE_FMT_S64P,
42+
AV_SAMPLE_FMT_S64,
43+
AV_SAMPLE_FMT_S32P,
44+
AV_SAMPLE_FMT_S32,
45+
AV_SAMPLE_FMT_S16P,
46+
AV_SAMPLE_FMT_S16,
47+
AV_SAMPLE_FMT_U8P,
48+
AV_SAMPLE_FMT_U8};
49+
50+
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
51+
// Find a sample format that the encoder supports. We prefer using FLT[P],
52+
// since this is the format of the input waveform. If FLTP isn't supported
53+
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
54+
// into the format with the highest resolution.
55+
if (avCodec.sample_fmts == nullptr) {
56+
// Can't really validate anything in this case, best we can do is hope that
57+
// FLTP is supported by the encoder. If not, FFmpeg will raise.
58+
return AV_SAMPLE_FMT_FLTP;
59+
}
60+
61+
for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
62+
for (int i = 0; avCodec.sample_fmts[i] != -1; ++i) {
63+
if (avCodec.sample_fmts[i] == preferredFormat) {
64+
return preferredFormat;
65+
}
66+
}
67+
}
68+
// We should always find a match in preferredFormatsOrder, so we should always
69+
// return earlier. But in the event that a future FFmpeg version defines an
70+
// additional sample format that isn't in preferredFormatsOrder, we fallback:
71+
return avCodec.sample_fmts[0];
72+
}
73+
3674
} // namespace
3775

3876
AudioEncoder::~AudioEncoder() {}
@@ -41,12 +79,14 @@ AudioEncoder::AudioEncoder(
4179
const torch::Tensor wf,
4280
int sampleRate,
4381
std::string_view fileName,
44-
std::optional<int64_t> bit_rate)
82+
std::optional<int64_t> bitRate)
4583
: wf_(wf) {
4684
TORCH_CHECK(
4785
wf_.dtype() == torch::kFloat32,
4886
"waveform must have float32 dtype, got ",
4987
wf_.dtype());
88+
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
89+
// planar (fltp).
5090
TORCH_CHECK(
5191
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
5292

@@ -82,24 +122,20 @@ AudioEncoder::AudioEncoder(
82122
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
83123
avCodecContext_.reset(avCodecContext);
84124

85-
if (bit_rate.has_value()) {
86-
TORCH_CHECK(*bit_rate >= 0, "bit_rate=", *bit_rate, " must be >= 0.");
125+
if (bitRate.has_value()) {
126+
TORCH_CHECK(*bitRate >= 0, "bit_rate=", *bitRate, " must be >= 0.");
87127
}
88128
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
89129
// well when "-b:a" isn't specified.
90-
avCodecContext_->bit_rate = bit_rate.value_or(0);
130+
avCodecContext_->bit_rate = bitRate.value_or(0);
91131

92132
validateSampleRate(*avCodec, sampleRate);
93133
avCodecContext_->sample_rate = sampleRate;
94134

95-
// Note: This is the format of the **input** waveform. This doesn't determine
96-
// the output.
97-
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
98-
// planar.
99-
// TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
100-
// raise. We need to handle this, probably converting the format with
101-
// libswresample.
102-
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
135+
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
136+
// may need to convert the wf into a supported output sample format, which is
137+
// what the `.sample_fmt` defines.
138+
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
103139

104140
int numChannels = static_cast<int>(wf_.sizes()[0]);
105141
TORCH_CHECK(
@@ -120,12 +156,6 @@ AudioEncoder::AudioEncoder(
120156
"avcodec_open2 failed: ",
121157
getFFMPEGErrorStringFromErrorCode(status));
122158

123-
TORCH_CHECK(
124-
avCodecContext_->frame_size > 0,
125-
"frame_size is ",
126-
avCodecContext_->frame_size,
127-
". Cannot encode. This should probably never happen?");
128-
129159
// We're allocating the stream here. Streams are meant to be freed by
130160
// avformat_free_context(avFormatContext), which we call in the
131161
// avFormatContext_'s destructor.
@@ -143,8 +173,11 @@ AudioEncoder::AudioEncoder(
143173
void AudioEncoder::encode() {
144174
UniqueAVFrame avFrame(av_frame_alloc());
145175
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
146-
avFrame->nb_samples = avCodecContext_->frame_size;
147-
avFrame->format = avCodecContext_->sample_fmt;
176+
// Default to 256 like in torchaudio
177+
int numSamplesAllocatedPerFrame =
178+
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
179+
avFrame->nb_samples = numSamplesAllocatedPerFrame;
180+
avFrame->format = AV_SAMPLE_FMT_FLTP;
148181
avFrame->sample_rate = avCodecContext_->sample_rate;
149182
avFrame->pts = 0;
150183
setChannelLayout(avFrame, avCodecContext_);
@@ -160,7 +193,6 @@ void AudioEncoder::encode() {
160193
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
161194
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
162195
int numEncodedSamples = 0; // per channel
163-
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
164196
int numBytesPerSample = static_cast<int>(wf_.element_size());
165197
int numBytesPerChannel = numSamples * numBytesPerSample;
166198

@@ -178,7 +210,7 @@ void AudioEncoder::encode() {
178210
getFFMPEGErrorStringFromErrorCode(status));
179211

180212
int numSamplesToEncode =
181-
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
213+
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
182214
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
183215

184216
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
@@ -211,7 +243,37 @@ void AudioEncoder::encode() {
211243

212244
void AudioEncoder::encodeInnerLoop(
213245
AutoAVPacket& autoAVPacket,
214-
const UniqueAVFrame& avFrame) {
246+
const UniqueAVFrame& srcAVFrame) {
247+
bool mustConvert =
248+
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
249+
srcAVFrame != nullptr);
250+
UniqueAVFrame convertedAVFrame;
251+
if (mustConvert) {
252+
if (!swrContext_) {
253+
swrContext_.reset(createSwrContext(
254+
avCodecContext_,
255+
AV_SAMPLE_FMT_FLTP,
256+
avCodecContext_->sample_fmt,
257+
srcAVFrame->sample_rate, // No sample rate conversion
258+
srcAVFrame->sample_rate));
259+
}
260+
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
261+
swrContext_,
262+
srcAVFrame,
263+
avCodecContext_->sample_fmt,
264+
srcAVFrame->sample_rate, // No sample rate conversion
265+
srcAVFrame->sample_rate);
266+
TORCH_CHECK(
267+
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
268+
"convertedAVFrame->nb_samples=",
269+
convertedAVFrame->nb_samples,
270+
" differs from ",
271+
"srcAVFrame->nb_samples=",
272+
srcAVFrame->nb_samples,
273+
"This is unexpected, please report on the TorchCodec bug tracker.");
274+
}
275+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
276+
215277
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
216278
TORCH_CHECK(
217279
status == AVSUCCESS,
@@ -248,6 +310,9 @@ void AudioEncoder::encodeInnerLoop(
248310
}
249311

250312
void AudioEncoder::flushBuffers() {
313+
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
314+
// swresample is only necessary when converting sample rates, which we don't
315+
// do for encoding.
251316
AutoAVPacket autoAVPacket;
252317
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
253318
}

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@ class AudioEncoder {
2020
// encoding will still work but audio will be distorted.
2121
int sampleRate,
2222
std::string_view fileName,
23-
std::optional<int64_t> bit_rate = std::nullopt);
23+
std::optional<int64_t> bitRate = std::nullopt);
2424
void encode();
2525

2626
private:
2727
void encodeInnerLoop(
2828
AutoAVPacket& autoAVPacket,
29-
const UniqueAVFrame& avFrame);
29+
const UniqueAVFrame& srcAVFrame);
3030
void flushBuffers();
3131

3232
UniqueEncodingAVFormatContext avFormatContext_;
3333
UniqueAVCodecContext avCodecContext_;
3434
int streamIndex_;
35+
UniqueSwrContext swrContext_;
3536

3637
const torch::Tensor wf_;
3738
};

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,17 @@ void setChannelLayout(
116116
#endif
117117
}
118118

119-
SwrContext* allocateSwrContext(
119+
SwrContext* createSwrContext(
120120
UniqueAVCodecContext& avCodecContext,
121121
AVSampleFormat sourceSampleFormat,
122122
AVSampleFormat desiredSampleFormat,
123123
int sourceSampleRate,
124124
int desiredSampleRate) {
125125
SwrContext* swrContext = nullptr;
126+
int status = AVSUCCESS;
126127
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
127128
AVChannelLayout layout = avCodecContext->ch_layout;
128-
auto status = swr_alloc_set_opts2(
129+
status = swr_alloc_set_opts2(
129130
&swrContext,
130131
&layout,
131132
desiredSampleFormat,
@@ -155,9 +156,77 @@ SwrContext* allocateSwrContext(
155156
#endif
156157

157158
TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext");
159+
status = swr_init(swrContext);
160+
TORCH_CHECK(
161+
status == AVSUCCESS,
162+
"Couldn't initialize SwrContext: ",
163+
getFFMPEGErrorStringFromErrorCode(status),
164+
". If the error says 'Invalid argument', it's likely that you are using "
165+
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
166+
"valid scenarios. Try to upgrade FFmpeg?");
158167
return swrContext;
159168
}
160169

170+
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
171+
const UniqueSwrContext& swrContext,
172+
const UniqueAVFrame& srcAVFrame,
173+
AVSampleFormat desiredSampleFormat,
174+
int sourceSampleRate,
175+
int desiredSampleRate) {
176+
UniqueAVFrame convertedAVFrame(av_frame_alloc());
177+
TORCH_CHECK(
178+
convertedAVFrame,
179+
"Could not allocate frame for sample format conversion.");
180+
181+
setChannelLayout(convertedAVFrame, srcAVFrame);
182+
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
183+
convertedAVFrame->sample_rate = desiredSampleRate;
184+
if (sourceSampleRate != desiredSampleRate) {
185+
// Note that this is an upper bound on the number of output samples.
186+
// `swr_convert()` will likely not fill convertedAVFrame with that many
187+
// samples if sample rate conversion is needed. It will buffer the last few
188+
// ones because those require future samples. That's also why we reset
189+
// nb_samples after the call to `swr_convert()`.
190+
// We could also use `swr_get_out_samples()` to determine the number of
191+
// output samples, but empirically `av_rescale_rnd()` seems to provide a
192+
// tighter bound.
193+
convertedAVFrame->nb_samples = av_rescale_rnd(
194+
swr_get_delay(swrContext.get(), sourceSampleRate) +
195+
srcAVFrame->nb_samples,
196+
desiredSampleRate,
197+
sourceSampleRate,
198+
AV_ROUND_UP);
199+
} else {
200+
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
201+
}
202+
203+
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
204+
TORCH_CHECK(
205+
status == AVSUCCESS,
206+
"Could not allocate frame buffers for sample format conversion: ",
207+
getFFMPEGErrorStringFromErrorCode(status));
208+
209+
auto numConvertedSamples = swr_convert(
210+
swrContext.get(),
211+
convertedAVFrame->data,
212+
convertedAVFrame->nb_samples,
213+
static_cast<const uint8_t**>(
214+
const_cast<const uint8_t**>(srcAVFrame->data)),
215+
srcAVFrame->nb_samples);
216+
// numConvertedSamples can be 0 if we're downsampling by a great factor and
217+
// the first frame doesn't contain a lot of samples. It should be handled
218+
// properly by the caller.
219+
TORCH_CHECK(
220+
numConvertedSamples >= 0,
221+
"Error in swr_convert: ",
222+
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));
223+
224+
// See comment above about nb_samples
225+
convertedAVFrame->nb_samples = numConvertedSamples;
226+
227+
return convertedAVFrame;
228+
}
229+
161230
void setFFmpegLogLevel() {
162231
auto logLevel = AV_LOG_QUIET;
163232
const char* logLevelEnvPtr = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL");

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,20 @@ void setChannelLayout(
158158
void setChannelLayout(
159159
UniqueAVFrame& dstAVFrame,
160160
const UniqueAVFrame& srcAVFrame);
161-
SwrContext* allocateSwrContext(
161+
SwrContext* createSwrContext(
162162
UniqueAVCodecContext& avCodecContext,
163163
AVSampleFormat sourceSampleFormat,
164164
AVSampleFormat desiredSampleFormat,
165165
int sourceSampleRate,
166166
int desiredSampleRate);
167167

168+
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
169+
const UniqueSwrContext& swrContext,
170+
const UniqueAVFrame& srcAVFrame,
171+
AVSampleFormat desiredSampleFormat,
172+
int sourceSampleRate,
173+
int desiredSampleRate);
174+
168175
// Returns true if sws_scale can handle unaligned data.
169176
bool canSwsScaleHandleUnalignedData();
170177

0 commit comments

Comments
 (0)