Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit c36f862

Browse files
alebzk authored and Commit Bot committed
Reland "RNN VAD: pitch search optimizations (part 1)"
This reverts commit 1b6b958. Reason for revert: Bug fix. Original change's description: > Revert "RNN VAD: pitch search optimizations (part 1)" > > This reverts commit 9da3e17. > > Reason for revert: bug in ComputePitchPeriod48kHz() > > Original change's description: > > RNN VAD: pitch search optimizations (part 1) > > > > TL;DR this CL improves efficiency and includes several code > > readability improvements mainly triggered by the comments to > > patch set #10. > > > > Highlights: > > - Split `FindBestPitchPeriods()` into 12 and 24 kHz versions > > to hard-code the input size and simplify the 24 kHz version > > - Loop in `ComputePitchPeriod48kHz()` (new name for > > `RefinePitchPeriod48kHz()`) removed since the lags for which > > we need to compute the auto correlation are only a few > > - `ComputePitchGainThreshold()` was only used in unit tests; it's been > > moved into the anon ns and the test removed > > > > This CL makes `ComputePitchPeriod48kHz()` about 10% faster (measured > > with https://webrtc-review.googlesource.com/c/src/+/191320/4/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc). > > The realtime factor has improved by about +14%. 
> > > > Benchmarked as follows: > > ``` > > out/release/modules_unittests \ > > --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \ > > --gtest_also_run_disabled_tests --logs > > ``` > > > > Results: > > > > | baseline | this CL > > ------+----------------------+------------------------ > > run 1 | 24.0231 +/- 0.591016 | 23.568 +/- 0.990788 > > | 370.06x | 377.207x > > ------+----------------------+------------------------ > > run 2 | 24.0485 +/- 0.957498 | 23.3714 +/- 0.857523 > > | 369.67x | 380.379x > > ------+----------------------+------------------------ > > run 3 | 25.4091 +/- 2.6123 | 23.709 +/- 1.04477 > > | 349.875x | 374.963x > > > > Bug: webrtc:10480 > > Change-Id: I9a3e9164b2442114b928de506c92a547c273882f > > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191320 > > Reviewed-by: Per Åhgren <[email protected]> > > Commit-Queue: Alessio Bazzica <[email protected]> > > Cr-Commit-Position: refs/heads/master@{#32568} > > [email protected],[email protected] > > No-Presubmit: true > No-Tree-Checks: true > No-Try: true > Bug: webrtc:10480 > Change-Id: I2a91f4f29566f872a7dfa220b31c6c625ed075db > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192660 > Commit-Queue: Alessio Bazzica <[email protected]> > Reviewed-by: Alessio Bazzica <[email protected]> > Cr-Commit-Position: refs/heads/master@{#32581} [email protected],[email protected] # Not skipping CQ checks because this is a reland. Bug: webrtc:10480 Change-Id: I66e3e8d73ebc04a437c01a0396cd5613c42a8cf5 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192780 Reviewed-by: Alessio Bazzica <[email protected]> Reviewed-by: Per Åhgren <[email protected]> Commit-Queue: Alessio Bazzica <[email protected]> Cr-Commit-Position: refs/heads/master@{#32585}
1 parent 01a36f3 commit c36f862

13 files changed

+488
-451
lines changed

modules/audio_processing/agc2/rnn_vad/BUILD.gn

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ rtc_library("rnn_vad_lp_residual") {
8383

8484
rtc_library("rnn_vad_pitch") {
8585
sources = [
86-
"pitch_info.h",
8786
"pitch_search.cc",
8887
"pitch_search.h",
8988
"pitch_search_internal.cc",
@@ -94,6 +93,7 @@ rtc_library("rnn_vad_pitch") {
9493
":rnn_vad_common",
9594
"../../../../api:array_view",
9695
"../../../../rtc_base:checks",
96+
"../../../../rtc_base:gtest_prod",
9797
"../../../../rtc_base:safe_compare",
9898
"../../../../rtc_base:safe_conversions",
9999
]

modules/audio_processing/agc2/rnn_vad/auto_correlation.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace {
2020

2121
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
2222
static_assert(1 << kAutoCorrelationFftOrder >
23-
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
23+
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
2424
"");
2525

2626
} // namespace
@@ -45,15 +45,15 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
4545
// pitch period.
4646
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
4747
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
48-
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
48+
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
4949
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
5050
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
5151
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
5252
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
5353
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
5454
"Mismatch between pitch buffer size, frame size and maximum "
5555
"pitch period.");
56-
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
56+
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
5757
"The FFT length is not sufficiently big to avoid cyclic "
5858
"convolution errors.");
5959
auto tmp = tmp_->GetView();
@@ -67,13 +67,12 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
6767

6868
// Compute the FFT for the sliding frames chunk. The sliding frames are
6969
// defined as pitch_buf[i:i+kConvolutionLength] where i in
70-
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
71-
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
70+
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
71+
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
7272
std::copy(pitch_buf.begin(),
73-
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
73+
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
7474
tmp.begin());
75-
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
76-
0.f);
75+
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
7776
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
7877

7978
// Convolve in the frequency domain.
@@ -84,7 +83,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
8483

8584
// Extract the auto-correlation coefficients.
8685
std::copy(tmp.begin() + kConvolutionLength - 1,
87-
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
86+
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
8887
auto_corr.begin());
8988
}
9089

modules/audio_processing/agc2/rnn_vad/auto_correlation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class AutoCorrelationCalculator {
3434
// |auto_corr| indexes are inverted lags.
3535
void ComputeOnPitchBuffer(
3636
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
37-
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
37+
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
3838

3939
private:
4040
Pffft fft_;

modules/audio_processing/agc2/rnn_vad/common.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@ constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
3636
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
3737
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
3838
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
39-
constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
39+
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
40+
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
41+
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
42+
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
43+
static_assert(
44+
kRefineNumLags24kHz > kInitialNumLags24kHz,
45+
"The refinement step must search the pitch in an extended pitch range.");
4046

4147
// 12 kHz analysis.
4248
constexpr int kSampleRate12kHz = 12000;
@@ -47,8 +53,8 @@ constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
4753
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
4854
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
4955
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
50-
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
51-
constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
56+
// |kMaxPitch12kHz|] are in the range [0, |kNumLags12kHz|].
57+
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
5258

5359
// 48 kHz constants.
5460
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;

modules/audio_processing/agc2/rnn_vad/features_extraction.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,12 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
6767
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
6868
// Estimate pitch on the LP-residual and write the normalized pitch period
6969
// into the output vector (normalization based on training data stats).
70-
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
71-
feature_vector[kFeatureVectorSize - 2] =
72-
0.01f * (pitch_info_48kHz_.period - 300);
70+
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
71+
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
7372
// Extract lagged frames (according to the estimated pitch period).
74-
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
73+
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
7574
auto lagged_frame = pitch_buf_24kHz_view_.subview(
76-
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
75+
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
7776
// Analyze reference and lagged frames checking if silence has been detected
7877
// and write the feature vector.
7978
return spectral_features_extractor_.CheckSilenceComputeFeatures(

modules/audio_processing/agc2/rnn_vad/features_extraction.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "api/array_view.h"
1717
#include "modules/audio_processing/agc2/biquad_filter.h"
1818
#include "modules/audio_processing/agc2/rnn_vad/common.h"
19-
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
2019
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
2120
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
2221
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
@@ -53,7 +52,7 @@ class FeaturesExtractor {
5352
PitchEstimator pitch_estimator_;
5453
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
5554
SpectralFeaturesExtractor spectral_features_extractor_;
56-
PitchInfo pitch_info_48kHz_;
55+
int pitch_period_48kHz_;
5756
};
5857

5958
} // namespace rnn_vad

modules/audio_processing/agc2/rnn_vad/pitch_info.h

Lines changed: 0 additions & 29 deletions
This file was deleted.

modules/audio_processing/agc2/rnn_vad/pitch_search.cc

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,37 @@ namespace rnn_vad {
2121
PitchEstimator::PitchEstimator()
2222
: pitch_buf_decimated_(kBufSize12kHz),
2323
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
24-
auto_corr_(kNumInvertedLags12kHz),
25-
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
24+
auto_corr_(kNumLags12kHz),
25+
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
2626
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
27-
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
27+
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
2828
}
2929

3030
PitchEstimator::~PitchEstimator() = default;
3131

32-
PitchInfo PitchEstimator::Estimate(
33-
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
32+
int PitchEstimator::Estimate(
33+
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
3434
// Perform the initial pitch search at 12 kHz.
35-
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
35+
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
3636
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
3737
auto_corr_view_);
38-
CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods(
39-
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
38+
CandidatePitchPeriods pitch_candidates_inverted_lags =
39+
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
4040
// Refine the pitch period estimation.
4141
// The refinement is done using the pitch buffer that contains 24 kHz samples.
4242
// Therefore, adapt the inverted lags in |pitch_candidates_inverted_lags| from 12
4343
// to 24 kHz.
4444
pitch_candidates_inverted_lags.best *= 2;
4545
pitch_candidates_inverted_lags.second_best *= 2;
4646
const int pitch_inv_lag_48kHz =
47-
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
47+
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
4848
// Look for stronger harmonics to find the final pitch period and its gain.
4949
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
50-
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
51-
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
52-
return last_pitch_48kHz_;
50+
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
51+
pitch_buffer,
52+
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
53+
last_pitch_48kHz_);
54+
return last_pitch_48kHz_.period;
5355
}
5456

5557
} // namespace rnn_vad

modules/audio_processing/agc2/rnn_vad/pitch_search.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include "api/array_view.h"
1818
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
1919
#include "modules/audio_processing/agc2/rnn_vad/common.h"
20-
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
2120
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
21+
#include "rtc_base/gtest_prod_util.h"
2222

2323
namespace webrtc {
2424
namespace rnn_vad {
@@ -30,17 +30,21 @@ class PitchEstimator {
3030
PitchEstimator(const PitchEstimator&) = delete;
3131
PitchEstimator& operator=(const PitchEstimator&) = delete;
3232
~PitchEstimator();
33-
// Estimates the pitch period and gain. Returns the pitch estimation data for
34-
// 48 kHz.
35-
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
33+
// Returns the estimated pitch period at 48 kHz.
34+
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
3635

3736
private:
38-
PitchInfo last_pitch_48kHz_;
37+
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
38+
float GetLastPitchStrengthForTesting() const {
39+
return last_pitch_48kHz_.strength;
40+
}
41+
42+
PitchInfo last_pitch_48kHz_{};
3943
AutoCorrelationCalculator auto_corr_calculator_;
4044
std::vector<float> pitch_buf_decimated_;
4145
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
4246
std::vector<float> auto_corr_;
43-
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
47+
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
4448
};
4549

4650
} // namespace rnn_vad

0 commit comments

Comments
 (0)