Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 5a37b94

Browse files
alebzkCommit Bot
authored andcommitted
Reland "RNN VAD: pitch search optimizations (part 2)"
This reverts commit e6a731f. Reason for revert: bug in parent CL fixed Original change's description: > Revert "RNN VAD: pitch search optimizations (part 2)" > > This reverts commit 2f7d1c6. > > Reason for revert: bug in ancestor CL https://webrtc-review.googlesource.com/c/src/+/191320 > > Original change's description: > > RNN VAD: pitch search optimizations (part 2) > > > > This CL brings a large improvement to the VAD by precomputing the > > energy for the sliding frame `y` in the pitch buffer instead of > > computing them twice in two different places. The realtime factor > > has improved by about +16x. > > > > There is room for additional improvement (TODOs added), but that will > > be done in a follow up CL since the change won't be bit-exact and > > careful testing is needed. > > > > Benchmarked as follows: > > ``` > > out/release/modules_unittests \ > > --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \ > > --gtest_also_run_disabled_tests --logs > > ``` > > > > Results: > > > > | baseline | this CL > > ------+----------------------+------------------------ > > run 1 | 23.568 +/- 0.990788 | 22.8319 +/- 1.46554 > > | 377.207x | 389.367x > > ------+----------------------+------------------------ > > run 2 | 23.3714 +/- 0.857523 | 22.4286 +/- 0.726449 > > | 380.379x | 396.369x > > ------+----------------------+------------------------ > > run 2 | 23.709 +/- 1.04477 | 22.5688 +/- 0.831341 > > | 374.963x | 393.906x > > > > Bug: webrtc:10480 > > Change-Id: I599a4dda2bde16dc6c2f42cf89e96afbd4630311 > > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191484 > > Reviewed-by: Per Åhgren <[email protected]> > > Commit-Queue: Alessio Bazzica <[email protected]> > > Cr-Commit-Position: refs/heads/master@{#32571} > > [email protected],[email protected] > > Change-Id: I53e478d8d58912c7a5fae4ad8a8d1342a9a48091 > No-Presubmit: true > No-Tree-Checks: true > No-Try: true > Bug: webrtc:10480 > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192620 > Reviewed-by: Alessio Bazzica <[email protected]> > Commit-Queue: Alessio Bazzica <[email protected]> > Cr-Commit-Position: refs/heads/master@{#32580} [email protected],[email protected] # Not skipping CQ checks because this is a reland. Bug: webrtc:10480 Change-Id: I0d6c89c64587bb6c38e69b968df12a5eb499ac6f Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192782 Commit-Queue: Alessio Bazzica <[email protected]> Reviewed-by: Alessio Bazzica <[email protected]> Reviewed-by: Per Åhgren <[email protected]> Cr-Commit-Position: refs/heads/master@{#32586}
1 parent c36f862 commit 5a37b94

File tree

5 files changed

+94
-79
lines changed

5 files changed

+94
-79
lines changed

modules/audio_processing/agc2/rnn_vad/pitch_search.cc

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,37 +19,46 @@ namespace webrtc {
1919
namespace rnn_vad {
2020

2121
PitchEstimator::PitchEstimator()
22-
: pitch_buf_decimated_(kBufSize12kHz),
23-
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
24-
auto_corr_(kNumLags12kHz),
25-
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
26-
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
27-
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
28-
}
22+
: y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
23+
pitch_buffer_12kHz_(kBufSize12kHz),
24+
auto_correlation_12kHz_(kNumLags12kHz) {}
2925

3026
PitchEstimator::~PitchEstimator() = default;
3127

3228
int PitchEstimator::Estimate(
3329
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
30+
rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
31+
pitch_buffer_12kHz_.data(), kBufSize12kHz);
32+
RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
33+
rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
34+
auto_correlation_12kHz_.data(), kNumLags12kHz);
35+
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
36+
auto_correlation_12kHz_view.size());
37+
3438
// Perform the initial pitch search at 12 kHz.
35-
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
36-
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
37-
auto_corr_view_);
38-
CandidatePitchPeriods pitch_candidates_inverted_lags =
39-
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
40-
// Refine the pitch period estimation.
39+
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
40+
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
41+
auto_correlation_12kHz_view);
42+
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
43+
pitch_buffer_12kHz_view, auto_correlation_12kHz_view);
4144
// The refinement is done using the pitch buffer that contains 24 kHz samples.
4245
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
4346
// to 24 kHz.
44-
pitch_candidates_inverted_lags.best *= 2;
45-
pitch_candidates_inverted_lags.second_best *= 2;
46-
const int pitch_inv_lag_48kHz =
47-
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
48-
// Look for stronger harmonics to find the final pitch period and its gain.
49-
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
47+
pitch_periods.best *= 2;
48+
pitch_periods.second_best *= 2;
49+
50+
// Refine the initial pitch period estimation from 12 kHz to 48 kHz.
51+
// Pre-compute frame energies at 24 kHz.
52+
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
53+
y_energy_24kHz_.data(), kRefineNumLags24kHz);
54+
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
55+
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view);
56+
// Estimation at 48 kHz.
57+
const int pitch_lag_48kHz =
58+
ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods);
5059
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
51-
pitch_buffer,
52-
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
60+
pitch_buffer, y_energy_24kHz_view,
61+
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
5362
last_pitch_48kHz_);
5463
return last_pitch_48kHz_.period;
5564
}

modules/audio_processing/agc2/rnn_vad/pitch_search.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,9 @@ class PitchEstimator {
4141

4242
PitchInfo last_pitch_48kHz_{};
4343
AutoCorrelationCalculator auto_corr_calculator_;
44-
std::vector<float> pitch_buf_decimated_;
45-
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
46-
std::vector<float> auto_corr_;
47-
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
44+
std::vector<float> y_energy_24kHz_;
45+
std::vector<float> pitch_buffer_12kHz_;
46+
std::vector<float> auto_correlation_12kHz_;
4847
};
4948

5049
} // namespace rnn_vad

modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -156,41 +156,30 @@ void ComputeAutoCorrelation(
156156
}
157157
}
158158

159-
int FindBestPitchPeriods24kHz(
159+
int ComputePitchPeriod24kHz(
160+
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
160161
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
161-
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
162+
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
162163
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
163164
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
164-
// Initialize the sliding 20 ms frame energy.
165-
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
166-
float denominator = std::inner_product(
167-
pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1,
168-
pitch_buffer.begin(), 1.f);
169-
// Search best pitch by looking at the scaled auto-correlation.
170165
int best_inverted_lag = 0; // Pitch period.
171166
float best_numerator = -1.f; // Pitch strength numerator.
172167
float best_denominator = 0.f; // Pitch strength denominator.
173168
for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz;
174169
++inverted_lag) {
175170
// A pitch candidate must have positive correlation.
176171
if (auto_correlation[inverted_lag] > 0.f) {
172+
// Auto-correlation energy normalized by frame energy.
177173
const float numerator =
178174
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
175+
const float denominator = y_energy[kMaxPitch24kHz - inverted_lag];
179176
// Compare numerator/denominator ratios without using divisions.
180177
if (numerator * best_denominator > best_numerator * denominator) {
181178
best_inverted_lag = inverted_lag;
182179
best_numerator = numerator;
183180
best_denominator = denominator;
184181
}
185182
}
186-
// Update |denominator| for the next inverted lag.
187-
static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz,
188-
"");
189-
const float y_old = pitch_buffer[inverted_lag];
190-
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
191-
denominator -= y_old * y_old;
192-
denominator += y_new * y_new;
193-
denominator = std::max(0.f, denominator);
194183
}
195184
return best_inverted_lag;
196185
}
@@ -341,6 +330,7 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
341330

342331
int ComputePitchPeriod48kHz(
343332
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
333+
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
344334
CandidatePitchPeriods pitch_candidates) {
345335
// Compute the auto-correlation terms only for neighbors of the given pitch
346336
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
@@ -369,14 +359,15 @@ int ComputePitchPeriod48kHz(
369359
}
370360
// Find best pitch at 24 kHz.
371361
const int pitch_candidate_24kHz =
372-
FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer);
362+
ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
373363
// Pseudo-interpolation.
374364
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
375365
auto_correlation);
376366
}
377367

378368
PitchInfo ComputeExtendedPitchPeriod48kHz(
379369
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
370+
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
380371
int initial_pitch_period_48kHz,
381372
PitchInfo last_pitch_48kHz) {
382373
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
@@ -386,34 +377,30 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
386377
struct RefinedPitchCandidate {
387378
int period;
388379
float strength;
389-
// Additional strength data used for the final estimation of the strength.
390-
float xy; // Cross-correlation.
391-
float yy; // Auto-correlation.
380+
// Additional strength data used for the final pitch estimation.
381+
float xy; // Auto-correlation.
382+
float y_energy; // Energy of the sliding frame `y`.
392383
};
393384

394-
// Initialize.
395-
std::array<float, kRefineNumLags24kHz> yy_values;
396-
// TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz().
397-
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values);
398-
const float xx = yy_values[0];
399-
const auto pitch_strength = [](float xy, float yy, float xx) {
400-
RTC_DCHECK_GE(xx * yy, 0.f);
401-
return xy / std::sqrt(1.f + xx * yy);
385+
const float x_energy = y_energy[0];
386+
const auto pitch_strength = [x_energy](float xy, float y_energy) {
387+
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
388+
return xy / std::sqrt(1.f + x_energy * y_energy);
402389
};
403-
// Initial pitch candidate.
390+
391+
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
404392
RefinedPitchCandidate best_pitch;
405393
best_pitch.period =
406394
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
407395
best_pitch.xy =
408396
ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer);
409-
best_pitch.yy = yy_values[best_pitch.period];
410-
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx);
411-
412-
// 24 kHz version of the last estimated pitch and copy of the initial
413-
// estimation.
397+
best_pitch.y_energy = y_energy[best_pitch.period];
398+
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
399+
// Keep a copy of the initial pitch candidate.
400+
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
401+
// 24 kHz version of the last estimated pitch.
414402
const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
415403
last_pitch_48kHz.strength};
416-
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
417404

418405
// Find `max_period_divisor` such that the result of
419406
// `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)`
@@ -443,14 +430,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
443430
// Compute an auto-correlation score for the primary pitch candidate
444431
// |alternative_pitch.period| by also looking at its possible sub-harmonic
445432
// |dual_alternative_period|.
446-
float xy_primary_period = ComputeAutoCorrelation(
433+
const float xy_primary_period = ComputeAutoCorrelation(
447434
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer);
448-
float xy_secondary_period = ComputeAutoCorrelation(
435+
const float xy_secondary_period = ComputeAutoCorrelation(
449436
kMaxPitch24kHz - dual_alternative_period, pitch_buffer);
450-
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
451-
float yy = 0.5f * (yy_values[alternative_pitch.period] +
452-
yy_values[dual_alternative_period]);
453-
alternative_pitch.strength = pitch_strength(xy, yy, xx);
437+
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
438+
const float yy = 0.5f * (y_energy[alternative_pitch.period] +
439+
y_energy[dual_alternative_period]);
440+
alternative_pitch.strength = pitch_strength(xy, yy);
454441

455442
// Maybe update best period.
456443
if (IsAlternativePitchStrongerThanInitial(
@@ -462,10 +449,11 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
462449

463450
// Final pitch strength and period.
464451
best_pitch.xy = std::max(0.f, best_pitch.xy);
465-
RTC_DCHECK_LE(0.f, best_pitch.yy);
466-
float final_pitch_strength = (best_pitch.yy <= best_pitch.xy)
467-
? 1.f
468-
: best_pitch.xy / (best_pitch.yy + 1.f);
452+
RTC_DCHECK_LE(0.f, best_pitch.y_energy);
453+
float final_pitch_strength =
454+
(best_pitch.y_energy <= best_pitch.xy)
455+
? 1.f
456+
: best_pitch.xy / (best_pitch.y_energy + 1.f);
469457
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
470458
int final_pitch_period_48kHz = std::max(
471459
kMinPitch48kHz,

modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
8080
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
8181
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation);
8282

83-
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer
84-
// and the pitch period candidates at 24 kHz (encoded as inverted lag).
83+
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
84+
// the energies for the sliding frames `y` at 24 kHz and the pitch period
85+
// candidates at 24 kHz (encoded as inverted lag).
8586
int ComputePitchPeriod48kHz(
8687
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
88+
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
8789
CandidatePitchPeriods pitch_candidates_24kHz);
8890

8991
struct PitchInfo {
@@ -92,10 +94,12 @@ struct PitchInfo {
9294
};
9395

9496
// Computes the pitch period at 48 kHz searching in an extended pitch range
95-
// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation
96-
// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch.
97+
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
98+
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
99+
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
97100
PitchInfo ComputeExtendedPitchPeriod48kHz(
98101
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
102+
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
99103
int initial_pitch_period_48kHz,
100104
PitchInfo last_pitch_48kHz);
101105

modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,17 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
6363
// Checks that the refined pitch period is bit-exact given test input data.
6464
TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
6565
PitchTestData test_data;
66+
std::vector<float> y_energy(kRefineNumLags24kHz);
67+
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
68+
kRefineNumLags24kHz);
69+
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
70+
y_energy_view);
6671
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
6772
// FloatingPointExceptionObserver fpe_observer;
68-
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
73+
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
6974
/*pitch_candidates=*/{280, 284}),
7075
560);
71-
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
76+
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
7277
/*pitch_candidates=*/{260, 284}),
7378
568);
7479
}
@@ -88,9 +93,14 @@ class PitchCandidatesParametrization
8893
TEST_P(PitchCandidatesParametrization,
8994
ComputePitchPeriod48kHzOrderDoesNotMatter) {
9095
PitchTestData test_data;
91-
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
96+
std::vector<float> y_energy(kRefineNumLags24kHz);
97+
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
98+
kRefineNumLags24kHz);
99+
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
100+
y_energy_view);
101+
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
92102
GetPitchCandidates()),
93-
ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
103+
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
94104
GetSwappedPitchCandidates()));
95105
}
96106

@@ -118,10 +128,15 @@ class ExtendedPitchPeriodSearchParametrizaion
118128
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
119129
PeriodBitExactnessGainWithinTolerance) {
120130
PitchTestData test_data;
131+
std::vector<float> y_energy(kMaxPitch24kHz + 1);
132+
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
133+
kMaxPitch24kHz + 1);
134+
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
135+
y_energy_view);
121136
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
122137
// FloatingPointExceptionObserver fpe_observer;
123138
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
124-
test_data.GetPitchBufView(), GetInitialPitchPeriod(),
139+
test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(),
125140
{GetLastPitchPeriod(), GetLastPitchStrength()});
126141
EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period);
127142
EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);

0 commit comments

Comments
 (0)