Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit b213a16

Browse files
peahCommit bot
authored andcommitted
Finalized the SSE2 optimizations for the matched filter in AEC3
The SSE2 optimizations of the filter core in the matched filter was only half-done. This CL finalizes those. In particular: -It adds finalization of updating of the filter. -It removes the manual loop unrolling in order to reduce and simplify the code. Note that the changes pass the bitexactness tests in an external AEC3 test suite, and the test MatchedFilter.TestOptimizations succeed. BUG=webrtc:6018 Review-Url: https://codereview.webrtc.org/2813563003 Cr-Commit-Position: refs/heads/master@{#17655}
1 parent c0d74d9 commit b213a16

File tree

2 files changed

+66
-36
lines changed

2 files changed

+66
-36
lines changed

webrtc/modules/audio_processing/aec3/matched_filter.cc

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,74 +31,104 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
3131
rtc::ArrayView<float> h,
3232
bool* filters_updated,
3333
float* error_sum) {
34+
const int h_size = static_cast<int>(h.size());
35+
const int x_size = static_cast<int>(x.size());
36+
RTC_DCHECK_EQ(0, h_size % 4);
37+
3438
// Process for all samples in the sub-block.
3539
for (size_t i = 0; i < kSubBlockSize; ++i) {
36-
// Apply the matched filter as filter * x. and compute x * x.
37-
float x2_sum = 0.f;
38-
float s = 0;
39-
size_t x_index = x_start_index;
40-
RTC_DCHECK_EQ(0, h.size() % 4);
40+
// Apply the matched filter as filter * x, and compute x * x.
41+
42+
RTC_DCHECK_GT(x_size, x_start_index);
43+
const float* x_p = &x[x_start_index];
44+
const float* h_p = &h[0];
4145

46+
// Initialize values for the accumulation.
4247
__m128 s_128 = _mm_set1_ps(0);
4348
__m128 x2_sum_128 = _mm_set1_ps(0);
49+
float x2_sum = 0.f;
50+
float s = 0;
4451

45-
size_t k = 0;
46-
if (h.size() > (x.size() - x_index)) {
47-
const size_t limit = x.size() - x_index;
48-
for (; (k + 3) < limit; k += 4, x_index += 4) {
49-
const __m128 x_k = _mm_loadu_ps(&x[x_index]);
50-
const __m128 h_k = _mm_loadu_ps(&h[k]);
52+
// Compute loop chunk sizes until, and after, the wraparound of the circular
53+
// buffer for x.
54+
const int chunk1 =
55+
std::min(h_size, static_cast<int>(x_size - x_start_index));
56+
57+
// Perform the loop in two chunks.
58+
const int chunk2 = h_size - chunk1;
59+
for (int limit : {chunk1, chunk2}) {
60+
// Perform 128 bit vector operations.
61+
const int limit_by_4 = limit >> 2;
62+
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
63+
// Load the data into 128 bit vectors.
64+
const __m128 x_k = _mm_loadu_ps(x_p);
65+
const __m128 h_k = _mm_loadu_ps(h_p);
5166
const __m128 xx = _mm_mul_ps(x_k, x_k);
67+
// Compute and accumulate x * x and h * x.
5268
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
5369
const __m128 hx = _mm_mul_ps(h_k, x_k);
5470
s_128 = _mm_add_ps(s_128, hx);
5571
}
5672

57-
for (; k < limit; ++k, ++x_index) {
58-
x2_sum += x[x_index] * x[x_index];
59-
s += h[k] * x[x_index];
73+
// Perform non-vector operations for any remaining items.
74+
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
75+
const float x_k = *x_p;
76+
x2_sum += x_k * x_k;
77+
s += *h_p * x_k;
6078
}
61-
x_index = 0;
62-
}
6379

64-
for (; k + 3 < h.size(); k += 4, x_index += 4) {
65-
const __m128 x_k = _mm_loadu_ps(&x[x_index]);
66-
const __m128 h_k = _mm_loadu_ps(&h[k]);
67-
const __m128 xx = _mm_mul_ps(x_k, x_k);
68-
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
69-
const __m128 hx = _mm_mul_ps(h_k, x_k);
70-
s_128 = _mm_add_ps(s_128, hx);
71-
}
72-
73-
for (; k < h.size(); ++k, ++x_index) {
74-
x2_sum += x[x_index] * x[x_index];
75-
s += h[k] * x[x_index];
80+
x_p = &x[0];
7681
}
7782

83+
// Combine the accumulated vector and scalar values.
7884
float* v = reinterpret_cast<float*>(&x2_sum_128);
7985
x2_sum += v[0] + v[1] + v[2] + v[3];
8086
v = reinterpret_cast<float*>(&s_128);
8187
s += v[0] + v[1] + v[2] + v[3];
8288

8389
// Compute the matched filter error.
8490
const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
85-
(*error_sum) += e * e;
91+
*error_sum += e * e;
8692

8793
// Update the matched filter estimate in an NLMS manner.
8894
if (x2_sum > x2_sum_threshold) {
8995
RTC_DCHECK_LT(0.f, x2_sum);
9096
const float alpha = 0.7f * e / x2_sum;
97+
const __m128 alpha_128 = _mm_set1_ps(alpha);
9198

9299
// filter = filter + 0.7 * (y - filter * x) / x * x.
93-
size_t x_index = x_start_index;
94-
for (size_t k = 0; k < h.size(); ++k) {
95-
h[k] += alpha * x[x_index];
96-
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
100+
float* h_p = &h[0];
101+
x_p = &x[x_start_index];
102+
103+
// Perform the loop in two chunks.
104+
for (int limit : {chunk1, chunk2}) {
105+
// Perform 128 bit vector operations.
106+
const int limit_by_4 = limit >> 2;
107+
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
108+
// Load the data into 128 bit vectors.
109+
__m128 h_k = _mm_loadu_ps(h_p);
110+
const __m128 x_k = _mm_loadu_ps(x_p);
111+
112+
// Compute h = h + alpha * x.
113+
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
114+
h_k = _mm_add_ps(h_k, alpha_x);
115+
116+
// Store the result.
117+
_mm_storeu_ps(h_p, h_k);
118+
}
119+
120+
// Perform non-vector operations for any remaining items.
121+
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
122+
*h_p += alpha * *x_p;
123+
}
124+
125+
x_p = &x[0];
97126
}
127+
98128
*filters_updated = true;
99129
}
100130

101-
x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
131+
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
102132
}
103133
}
104134
#endif
@@ -112,7 +142,7 @@ void MatchedFilterCore(size_t x_start_index,
112142
float* error_sum) {
113143
// Process for all samples in the sub-block.
114144
for (size_t i = 0; i < kSubBlockSize; ++i) {
115-
// Apply the matched filter as filter * x. and compute x * x.
145+
// Apply the matched filter as filter * x, and compute x * x.
116146
float x2_sum = 0.f;
117147
float s = 0;
118148
size_t x_index = x_start_index;

webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ TEST(MatchedFilter, TestOptimizations) {
7474
EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
7575

7676
for (size_t j = 0; j < h.size(); ++j) {
77-
EXPECT_NEAR(h[j], h_SSE2[j], 0.001f);
77+
EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
7878
}
7979

8080
x_index = (x_index + kSubBlockSize) % x.size();

0 commit comments

Comments
 (0)