Skip to content

Commit 4be2953

Browse files
committed
Add flushing logic for swresample buffers
1 parent 823e7f0 commit 4be2953

File tree

3 files changed

+79
-27
lines changed

3 files changed

+79
-27
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,23 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
9393
return avCodec.sample_fmts[0];
9494
}
9595

96+
// Allocates an AVFrame together with its sample buffers.
//
//   numSamples   - written to the frame's nb_samples field.
//   sampleRate   - written to the frame's sample_rate field.
//   numChannels  - number of channels used to pick the frame's default channel
//                  layout.
//   sampleFormat - written to the frame's format field. Defaults to
//                  AV_SAMPLE_FMT_FLTP, which is the value this helper used to
//                  hard-code, so existing 3-argument callers are unaffected.
//
// Returns the owning UniqueAVFrame. Fails through TORCH_CHECK if either the
// frame itself or its buffers cannot be allocated.
UniqueAVFrame allocateAVFrame(
    int numSamples,
    int sampleRate,
    int numChannels,
    AVSampleFormat sampleFormat = AV_SAMPLE_FMT_FLTP) {
  auto avFrame = UniqueAVFrame(av_frame_alloc());
  TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");

  avFrame->nb_samples = numSamples;
  avFrame->format = sampleFormat;
  avFrame->sample_rate = sampleRate;
  // Use the same setDefaultChannelLayout() helper that encode() relied on
  // before this allocation logic was factored out, rather than touching
  // avFrame->ch_layout directly.
  // NOTE(review): presumably that helper hides FFmpeg-version differences in
  // the channel-layout API - confirm against its definition.
  setDefaultChannelLayout(avFrame, numChannels);

  // nb_samples, format and the channel layout are all set above, before the
  // buffers are requested.
  auto status = av_frame_get_buffer(avFrame.get(), 0);
  TORCH_CHECK(
      status == AVSUCCESS,
      "Couldn't allocate avFrame's buffers: ",
      getFFMPEGErrorStringFromErrorCode(status));

  return avFrame;
}
112+
96113
} // namespace
97114

98115
AudioEncoder::~AudioEncoder() {}
@@ -228,24 +245,14 @@ void AudioEncoder::encode() {
228245
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
229246
encodeWasCalled_ = true;
230247

231-
UniqueAVFrame avFrame(av_frame_alloc());
232-
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
233248
// Default to 256 like in torchaudio
234249
int numSamplesAllocatedPerFrame =
235250
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
236-
avFrame->nb_samples = numSamplesAllocatedPerFrame;
237-
avFrame->format = AV_SAMPLE_FMT_FLTP;
238-
avFrame->sample_rate = sampleRateInput_;
251+
UniqueAVFrame avFrame = allocateAVFrame(
252+
numSamplesAllocatedPerFrame,
253+
sampleRateInput_,
254+
static_cast<int>(wf_.sizes()[0]));
239255
avFrame->pts = 0;
240-
// We set the channel layout of the frame to the default layout corresponding
241-
// to the input samples' number of channels
242-
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
243-
244-
auto status = av_frame_get_buffer(avFrame.get(), 0);
245-
TORCH_CHECK(
246-
status == AVSUCCESS,
247-
"Couldn't allocate avFrame's buffers: ",
248-
getFFMPEGErrorStringFromErrorCode(status));
249256

250257
AutoAVPacket autoAVPacket;
251258

@@ -255,7 +262,7 @@ void AudioEncoder::encode() {
255262
int numBytesPerSample = static_cast<int>(wf_.element_size());
256263
int numBytesPerChannel = numSamples * numBytesPerSample;
257264

258-
status = avformat_write_header(avFormatContext_.get(), nullptr);
265+
auto status = avformat_write_header(avFormatContext_.get(), nullptr);
259266
TORCH_CHECK(
260267
status == AVSUCCESS,
261268
"Error in avformat_write_header: ",
@@ -302,10 +309,14 @@ void AudioEncoder::encode() {
302309

303310
void AudioEncoder::encodeInnerLoop(
304311
AutoAVPacket& autoAVPacket,
305-
const UniqueAVFrame& srcAVFrame) {
312+
const UniqueAVFrame& srcAVFrame,
313+
bool allowConvert) {
314+
// TODO: Probably makes more sense to move the conversion away? It shouldn't
315+
// be in inner loop in any case. We should also remove allowConvert.
306316
bool mustConvert =
307-
(srcAVFrame != nullptr &&
308-
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
317+
(allowConvert && srcAVFrame != nullptr &&
318+
(static_cast<AVSampleFormat>(srcAVFrame->format) !=
319+
avCodecContext_->sample_fmt ||
309320
getNumChannels(srcAVFrame) != outNumChannels_ ||
310321
srcAVFrame->sample_rate != outSampleRate_));
311322

@@ -377,10 +388,31 @@ void AudioEncoder::encodeInnerLoop(
377388
}
378389
}
379390

391+
// Flushes samples that libswresample may still be holding and sends them to
// the encoder through encodeInnerLoop().
//
//   autoAVPacket - forwarded as-is to encodeInnerLoop().
//
// Does nothing when there is no swrContext_ and the input and output sample
// rates match, or when libswresample reports/produces no remaining samples.
// Fails through TORCH_CHECK on libswresample errors.
void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) {
  // Similar to the decoder's method with the same name, but for encoding this
  // time. That is, when sample conversion is involved, libswresample may have
  // buffered some samples that we now need to flush and send to the encoder.
  if (swrContext_ == nullptr && sampleRateInput_ == outSampleRate_) {
    return;
  }
  // Past the guard above swrContext_ can still be null (when the sample rates
  // differ). Fail loudly instead of handing a null context to libswresample.
  TORCH_CHECK(
      swrContext_ != nullptr,
      "swrContext is null, but sample rate conversion is needed. ",
      "This is unexpected, please report on the TorchCodec bug tracker.");

  int numRemainingSamples = // this is an upper bound
      swr_get_out_samples(swrContext_.get(), 0);
  // A negative value is an error code, not a sample count.
  TORCH_CHECK(
      numRemainingSamples >= 0,
      "Error in swr_get_out_samples: ",
      getFFMPEGErrorStringFromErrorCode(numRemainingSamples));
  if (numRemainingSamples == 0) {
    return;
  }

  // NOTE(review): allocateAVFrame() produces an AV_SAMPLE_FMT_FLTP frame, and
  // the frame is then sent with allowConvert=false. If swrContext_ outputs a
  // different sample format (e.g. the codec's sample_fmt), this frame's format
  // and buffers would not match what swr_convert writes - confirm.
  UniqueAVFrame avFrame =
      allocateAVFrame(numRemainingSamples, outSampleRate_, outNumChannels_);

  // Null input with a zero input count: only drain what is already buffered.
  int actualNumRemainingSamples = swr_convert(
      swrContext_.get(), avFrame->data, avFrame->nb_samples, nullptr, 0);
  TORCH_CHECK(
      actualNumRemainingSamples >= 0,
      "Error in swr_convert: ",
      getFFMPEGErrorStringFromErrorCode(actualNumRemainingSamples));
  if (actualNumRemainingSamples == 0) {
    // The upper bound was not reached at all: there is nothing to encode.
    return;
  }
  // The frame was sized for the upper bound; shrink it to what was written.
  avFrame->nb_samples = actualNumRemainingSamples;

  // The flushed samples already went through swrContext_, so they must not be
  // converted a second time.
  encodeInnerLoop(autoAVPacket, avFrame, false);
}
412+
380413
void AudioEncoder::flushBuffers() {
381-
// TODO Need to fluh libwresample buffers since we may be doing sample
382-
// rate conversion!!!
383414
AutoAVPacket autoAVPacket;
415+
maybeFlushSwrBuffers(autoAVPacket);
384416
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
385417
}
386418
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ class AudioEncoder {
3838
void initializeEncoder(const AudioStreamOptions& audioStreamOptions);
3939
void encodeInnerLoop(
4040
AutoAVPacket& autoAVPacket,
41-
const UniqueAVFrame& srcAVFrame);
41+
const UniqueAVFrame& srcAVFrame,
42+
bool allowConvert = true);
43+
void maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket);
4244
void flushBuffers();
4345

4446
UniqueEncodingAVFormatContext avFormatContext_;

test/test_encoders.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,23 @@ def test_round_trip(self, method, format, tmp_path):
118118
)
119119

120120
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
121-
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
122-
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
123-
@pytest.mark.parametrize("num_channels", (None, 1, 2))
124-
@pytest.mark.parametrize("format", ("mp3", "wav", "flac"))
121+
# @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
122+
@pytest.mark.parametrize("asset", (SINE_MONO_S32,))
123+
# @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3,))
124+
# @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
125+
@pytest.mark.parametrize("bit_rate", (None,))
126+
# @pytest.mark.parametrize("num_channels", (None, 1, 2))
127+
@pytest.mark.parametrize("num_channels", (None,))
128+
# @pytest.mark.parametrize("sample_rate", (None, 32_000))
129+
# @pytest.mark.parametrize("sample_rate", (32_000,))
130+
@pytest.mark.parametrize("sample_rate", (8_000, 32_000))
131+
# @pytest.mark.parametrize("format", ("mp3", "wav", "flac"))
132+
@pytest.mark.parametrize("format", ("wav",))
125133
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
126-
def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_path):
134+
# @pytest.mark.parametrize("method", ("to_file",))#, "to_tensor"))
135+
def test_against_cli(
136+
self, asset, bit_rate, num_channels, sample_rate, format, method, tmp_path
137+
):
127138
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
128139
# that both decoded outputs are equal
129140

@@ -135,6 +146,7 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
135146
["ffmpeg", "-i", str(asset.path)]
136147
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
137148
+ (["-ac", f"{num_channels}"] if num_channels is not None else [])
149+
+ (["-ar", f"{sample_rate}"] if sample_rate is not None else [])
138150
+ [
139151
str(encoded_by_ffmpeg),
140152
],
@@ -143,7 +155,9 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
143155
)
144156

145157
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
146-
params = dict(bit_rate=bit_rate, num_channels=num_channels)
158+
params = dict(
159+
bit_rate=bit_rate, num_channels=num_channels, sample_rate=sample_rate
160+
)
147161
if method == "to_file":
148162
encoded_by_us = tmp_path / f"output.{format}"
149163
encoder.to_file(dest=str(encoded_by_us), **params)
@@ -161,6 +175,10 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
161175
else:
162176
rtol, atol = None, None
163177
torch.testing.assert_close(
178+
# self.decode(encoded_by_ffmpeg)[:, :-100],
179+
# self.decode(encoded_by_us)[:, :-100],
180+
# self.decode(encoded_by_ffmpeg)[:, :-32],
181+
# self.decode(encoded_by_us)[:, :-32],
164182
self.decode(encoded_by_ffmpeg),
165183
self.decode(encoded_by_us),
166184
rtol=rtol,

0 commit comments

Comments
 (0)