Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit f2a2fe8

Browse files
alebzkCommit Bot
authored andcommitted
Reland "RNN VAD: pitch search optimizations (part 3)"
This reverts commit 57e68ee. Reason for revert: bug in ancestor CL fixed Original change's description: > Revert "RNN VAD: pitch search optimizations (part 3)" > > This reverts commit ea89f2a. > > Reason for revert: bug in ancestor CL https://webrtc-review.googlesource.com/c/src/+/191320 > > Original change's description: > > RNN VAD: pitch search optimizations (part 3) > > > > `ComputeSlidingFrameSquareEnergies()` which computes the energy of a > > sliding 20 ms frame in the pitch buffer has been switched from backward > > to forward. > > > > The benchmark has shown a slight improvement (about +6x). > > > > This change is not bit exact but all the tolerance tests still pass > > except for one single case in `RnnVadTest,PitchSearchWithinTolerance` > > for which the tolerance has been slightly increased. Note that the pitch > > estimation is still bit-exact. > > > > Benchmarked as follows: > > ``` > > out/release/modules_unittests \ > > --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \ > > --gtest_also_run_disabled_tests --logs > > ``` > > > > Results: > > > > | baseline | this CL > > ------+----------------------+------------------------ > > run 1 | 22.8319 +/- 1.46554 | 22.087 +/- 0.552932 > > | 389.367x | 402.499x > > ------+----------------------+------------------------ > > run 2 | 22.4286 +/- 0.726449 | 22.216 +/- 0.916222 > > | 396.369x | 400.162x > > ------+----------------------+------------------------ > > run 2 | 22.5688 +/- 0.831341 | 22.4902 +/- 1.04881 > > | 393.906x | 395.283x > > > > Bug: webrtc:10480 > > Change-Id: I1fd54077a32e25e46196c8e18f003cd0ffd503e1 > > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191703 > > Commit-Queue: Alessio Bazzica <[email protected]> > > Reviewed-by: Karl Wiberg <[email protected]> > > Cr-Commit-Position: refs/heads/master@{#32572} > > [email protected],[email protected] > > Change-Id: I57a8f937ade0a35e1ccf0e229c391cc3a10e7c48 > No-Presubmit: true > No-Tree-Checks: true > No-Try: true > Bug: webrtc:10480 > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192621 > Reviewed-by: Alessio Bazzica <[email protected]> > Commit-Queue: Alessio Bazzica <[email protected]> > Cr-Commit-Position: refs/heads/master@{#32578} [email protected],[email protected] # Not skipping CQ checks because this is a reland. Bug: webrtc:10480 Change-Id: I1d510697236255d8c0cca405e90781f5d8c6a3e6 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192783 Commit-Queue: Alessio Bazzica <[email protected]> Reviewed-by: Alessio Bazzica <[email protected]> Reviewed-by: Karl Wiberg <[email protected]> Cr-Commit-Position: refs/heads/master@{#32587}
1 parent 5a37b94 commit f2a2fe8

File tree

5 files changed

+27
-21
lines changed

5 files changed

+27
-21
lines changed

modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ int ComputePitchPeriod24kHz(
172172
// Auto-correlation energy normalized by frame energy.
173173
const float numerator =
174174
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
175-
const float denominator = y_energy[kMaxPitch24kHz - inverted_lag];
175+
const float denominator = y_energy[inverted_lag];
176176
// Compare numerator/denominator ratios without using divisions.
177177
if (numerator * best_denominator > best_numerator * denominator) {
178178
best_inverted_lag = inverted_lag;
@@ -256,19 +256,19 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
256256

257257
void ComputeSlidingFrameSquareEnergies24kHz(
258258
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
259-
rtc::ArrayView<float, kRefineNumLags24kHz> yy_values) {
260-
float yy = ComputeAutoCorrelation(kMaxPitch24kHz, pitch_buffer);
261-
yy_values[0] = yy;
262-
static_assert(kMaxPitch24kHz - (kRefineNumLags24kHz - 1) >= 0, "");
259+
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy) {
260+
float yy = std::inner_product(pitch_buffer.begin(),
261+
pitch_buffer.begin() + kFrameSize20ms24kHz,
262+
pitch_buffer.begin(), 0.f);
263+
y_energy[0] = yy;
263264
static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
264-
for (int lag = 1; lag < kRefineNumLags24kHz; ++lag) {
265-
const int inverted_lag = kMaxPitch24kHz - lag;
266-
const float y_old = pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
267-
const float y_new = pitch_buffer[inverted_lag];
268-
yy -= y_old * y_old;
269-
yy += y_new * y_new;
270-
yy = std::max(0.f, yy);
271-
yy_values[lag] = yy;
265+
static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, "");
266+
for (int inverted_lag = 0; inverted_lag < kMaxPitch24kHz; ++inverted_lag) {
267+
yy -= pitch_buffer[inverted_lag] * pitch_buffer[inverted_lag];
268+
yy += pitch_buffer[inverted_lag + kFrameSize20ms24kHz] *
269+
pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
270+
yy = std::max(1.f, yy);
271+
y_energy[inverted_lag + 1] = yy;
272272
}
273273
}
274274

@@ -382,7 +382,7 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
382382
float y_energy; // Energy of the sliding frame `y`.
383383
};
384384

385-
const float x_energy = y_energy[0];
385+
const float x_energy = y_energy[kMaxPitch24kHz];
386386
const auto pitch_strength = [x_energy](float xy, float y_energy) {
387387
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
388388
return xy / std::sqrt(1.f + x_energy * y_energy);
@@ -394,7 +394,7 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
394394
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
395395
best_pitch.xy =
396396
ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer);
397-
best_pitch.y_energy = y_energy[best_pitch.period];
397+
best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period];
398398
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
399399
// Keep a copy of the initial pitch candidate.
400400
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
@@ -435,8 +435,9 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
435435
const float xy_secondary_period = ComputeAutoCorrelation(
436436
kMaxPitch24kHz - dual_alternative_period, pitch_buffer);
437437
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
438-
const float yy = 0.5f * (y_energy[alternative_pitch.period] +
439-
y_energy[dual_alternative_period]);
438+
const float yy =
439+
0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] +
440+
y_energy[kMaxPitch24kHz - dual_alternative_period]);
440441
alternative_pitch.strength = pitch_strength(xy, yy);
441442

442443
// Maybe update best period.

modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
6262
// corresponding pitch period.
6363

6464
// Computes the sum of squared samples for every sliding frame `y` in the pitch
65-
// buffer. The indexes of `yy_values` are lags.
65+
// buffer. The indexes of `y_energy` are inverted lags.
6666
void ComputeSlidingFrameSquareEnergies24kHz(
6767
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
68-
rtc::ArrayView<float, kRefineNumLags24kHz> yy_values);
68+
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy);
6969

7070
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
7171
struct CandidatePitchPeriods {

modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
4242
computed_output);
4343
auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
4444
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
45-
computed_output, 3e-2f);
45+
computed_output, 1e-3f);
4646
}
4747

4848
// Checks that the estimated pitch period is bit-exact given test input data.

modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) {
4242
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
4343
EXPECT_EQ(expected_pitch_period, pitch_period);
4444
EXPECT_NEAR(expected_pitch_strength,
45-
pitch_estimator.GetLastPitchStrengthForTesting(), 1e-5f);
45+
pitch_estimator.GetLastPitchStrengthForTesting(), 15e-6f);
4646
}
4747
}
4848
}

modules/audio_processing/agc2/rnn_vad/test_utils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
1212

13+
#include <algorithm>
1314
#include <memory>
1415

1516
#include "rtc_base/checks.h"
@@ -86,6 +87,10 @@ PitchTestData::PitchTestData() {
8687
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
8788
1396);
8889
test_data_reader.ReadChunk(test_data_);
90+
// Reverse the order of the squared energy values.
91+
// Required after the WebRTC CL 191703 which switched to forward computation.
92+
std::reverse(test_data_.begin() + kBufSize24kHz,
93+
test_data_.begin() + kBufSize24kHz + kNumPitchBufSquareEnergies);
8994
}
9095

9196
PitchTestData::~PitchTestData() = default;

0 commit comments

Comments
 (0)