Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 05f5d63

Browse files
alebzkCommit Bot
authored andcommitted
RNN VAD: pitch search optimizations (part 4)
Add inverted lags index to simplify the loop in `FindBestPitchPeriod48kHz()`. Instead of looping over 294 items, only loop over the relevant ones (up to 10) by keeping track of the relevant indexes. The benchmark has shown a slight improvement (about +6x). Benchmarked as follows: ``` out/release/modules_unittests \ --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \ --gtest_also_run_disabled_tests --logs ``` Results: | baseline | this CL ------+----------------------+------------------------ run 1 | 22.8319 +/- 1.46554 | 22.1951 +/- 0.747611 | 389.367x | 400.539x ------+----------------------+------------------------ run 2 | 22.4286 +/- 0.726449 | 22.2718 +/- 0.963738 | 396.369x | 399.16x ------+----------------------+------------------------ run 2 | 22.5688 +/- 0.831341 | 22.4166 +/- 0.953362 | 393.906x | 396.581x This CL also moved `PitchPseudoInterpolationInvLagAutoCorr()` into `FindBestPitchPeriod48kHz()`. Bug: webrtc:10480 Change-Id: Id4e6d755045c3198a80fa94a0a7463577d909b7e Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191764 Commit-Queue: Alessio Bazzica <[email protected]> Reviewed-by: Karl Wiberg <[email protected]> Cr-Commit-Position: refs/heads/master@{#32590}
1 parent ccbc216 commit 05f5d63

File tree

2 files changed

+76
-44
lines changed

2 files changed

+76
-44
lines changed

modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ int PitchPseudoInterpolationLagPitchBuf(
7979
return 2 * lag + offset;
8080
}
8181

82-
// Refines a pitch period |inverted_lag| encoded as inverted lag with
83-
// pseudo-interpolation. The output sample rate is twice as that of
84-
// |inverted_lag|.
85-
int PitchPseudoInterpolationInvLagAutoCorr(
86-
int inverted_lag,
87-
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation) {
88-
int offset = 0;
89-
// Cannot apply pseudo-interpolation at the boundaries.
90-
if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) {
91-
offset = GetPitchPseudoInterpolationOffset(
92-
auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag],
93-
auto_correlation[inverted_lag - 1]);
94-
}
95-
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
96-
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
97-
return 2 * inverted_lag + offset;
98-
}
99-
10082
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
10183
// looking for sub-harmonics.
10284
// The values have been chosen to serve the following algorithm. Given the
@@ -129,44 +111,83 @@ struct Range {
129111
int max;
130112
};
131113

114+
// Number of analyzed pitches to the left(right) of a pitch candidate.
115+
constexpr int kPitchNeighborhoodRadius = 2;
116+
132117
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
133118
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
134119
// pitch buffer.
135120
Range CreateInvertedLagRange(int inverted_lag) {
136-
constexpr int kRadius = 2;
137-
return {std::max(inverted_lag - kRadius, 0),
138-
std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)};
121+
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
122+
std::min(inverted_lag + kPitchNeighborhoodRadius,
123+
kInitialNumLags24kHz - 1)};
139124
}
140125

126+
constexpr int kNumPitchCandidates = 2; // Best and second best.
127+
// Maximum number of analyzed pitch periods.
128+
constexpr int kMaxPitchPeriods24kHz =
129+
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
130+
131+
// Collection of inverted lags.
132+
class InvertedLagsIndex {
133+
public:
134+
InvertedLagsIndex() : num_entries_(0) {}
135+
// Adds an inverted lag to the index. Cannot add more than
136+
// `kMaxPitchPeriods24kHz` values.
137+
void Append(int inverted_lag) {
138+
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
139+
inverted_lags_[num_entries_++] = inverted_lag;
140+
}
141+
const int* data() const { return inverted_lags_.data(); }
142+
int size() const { return num_entries_; }
143+
144+
private:
145+
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
146+
int num_entries_;
147+
};
148+
141149
// Computes the auto correlation coefficients for the inverted lags in the
142-
// closed interval `inverted_lags`.
150+
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
151+
// the inverted lags for the computed auto correlation values.
143152
void ComputeAutoCorrelation(
144153
Range inverted_lags,
145154
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
146-
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation) {
155+
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
156+
InvertedLagsIndex& inverted_lags_index) {
147157
// Check valid range.
148158
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
159+
// Trick to avoid zero initialization of `auto_correlation`.
160+
// Needed by the pseudo-interpolation.
161+
if (inverted_lags.min > 0) {
162+
auto_correlation[inverted_lags.min - 1] = 0.f;
163+
}
164+
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
165+
auto_correlation[inverted_lags.max + 1] = 0.f;
166+
}
149167
// Check valid `inverted_lag` indexes.
150168
RTC_DCHECK_GE(inverted_lags.min, 0);
151-
RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size());
169+
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
152170
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
153171
++inverted_lag) {
154172
auto_correlation[inverted_lag] =
155173
ComputeAutoCorrelation(inverted_lag, pitch_buffer);
174+
inverted_lags_index.Append(inverted_lag);
156175
}
157176
}
158177

159-
int ComputePitchPeriod24kHz(
178+
// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
179+
// 48 kHz.
180+
int ComputePitchPeriod48kHz(
160181
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
182+
rtc::ArrayView<const int> inverted_lags,
161183
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
162184
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
163185
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
164186
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
165187
int best_inverted_lag = 0; // Pitch period.
166188
float best_numerator = -1.f; // Pitch strength numerator.
167189
float best_denominator = 0.f; // Pitch strength denominator.
168-
for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz;
169-
++inverted_lag) {
190+
for (int inverted_lag : inverted_lags) {
170191
// A pitch candidate must have positive correlation.
171192
if (auto_correlation[inverted_lag] > 0.f) {
172193
// Auto-correlation energy normalized by frame energy.
@@ -181,7 +202,19 @@ int ComputePitchPeriod24kHz(
181202
}
182203
}
183204
}
184-
return best_inverted_lag;
205+
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
206+
// 48 kHz pitch period.
207+
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
208+
// Cannot apply pseudo-interpolation at the boundaries.
209+
return best_inverted_lag * 2;
210+
}
211+
int offset = GetPitchPseudoInterpolationOffset(
212+
auto_correlation[best_inverted_lag + 1],
213+
auto_correlation[best_inverted_lag],
214+
auto_correlation[best_inverted_lag - 1]);
215+
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
216+
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
217+
return 2 * best_inverted_lag + offset;
185218
}
186219

187220
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
@@ -332,10 +365,10 @@ int ComputePitchPeriod48kHz(
332365
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
333366
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
334367
CandidatePitchPeriods pitch_candidates) {
335-
// Compute the auto-correlation terms only for neighbors of the given pitch
336-
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
337-
// for a few lag values).
338-
std::array<float, kInitialNumLags24kHz> auto_correlation{};
368+
// Compute the auto-correlation terms only for neighbors of the two pitch
369+
// candidates (best and second best).
370+
std::array<float, kInitialNumLags24kHz> auto_correlation;
371+
InvertedLagsIndex inverted_lags_index;
339372
// Create two inverted lag ranges so that `r1` precedes `r2`.
340373
const bool swap_candidates =
341374
pitch_candidates.best > pitch_candidates.second_best;
@@ -351,18 +384,17 @@ int ComputePitchPeriod48kHz(
351384
RTC_DCHECK_LE(r1.max, r2.max);
352385
if (r1.max + 1 >= r2.min) {
353386
// Overlapping or adjacent ranges.
354-
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation);
387+
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
388+
inverted_lags_index);
355389
} else {
356390
// Disjoint ranges.
357-
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation);
358-
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation);
391+
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
392+
inverted_lags_index);
393+
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
394+
inverted_lags_index);
359395
}
360-
// Find best pitch at 24 kHz.
361-
const int pitch_candidate_24kHz =
362-
ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
363-
// Pseudo-interpolation.
364-
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
365-
auto_correlation);
396+
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
397+
auto_correlation, y_energy);
366398
}
367399

368400
PitchInfo ComputeExtendedPitchPeriod48kHz(

modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ class ExtendedPitchPeriodSearchParametrizaion
128128
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
129129
PeriodBitExactnessGainWithinTolerance) {
130130
PitchTestData test_data;
131-
std::vector<float> y_energy(kMaxPitch24kHz + 1);
132-
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
133-
kMaxPitch24kHz + 1);
131+
std::vector<float> y_energy(kRefineNumLags24kHz);
132+
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
133+
kRefineNumLags24kHz);
134134
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
135135
y_energy_view);
136136
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.

0 commit comments

Comments
 (0)