@@ -31,74 +31,104 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
31
31
rtc::ArrayView<float > h,
32
32
bool * filters_updated,
33
33
float * error_sum) {
34
+ const int h_size = static_cast <int >(h.size ());
35
+ const int x_size = static_cast <int >(x.size ());
36
+ RTC_DCHECK_EQ (0 , h_size % 4 );
37
+
34
38
// Process for all samples in the sub-block.
35
39
for (size_t i = 0 ; i < kSubBlockSize ; ++i) {
36
- // Apply the matched filter as filter * x. and compute x * x.
37
- float x2_sum = 0 . f ;
38
- float s = 0 ;
39
- size_t x_index = x_start_index;
40
- RTC_DCHECK_EQ ( 0 , h. size () % 4 ) ;
40
+ // Apply the matched filter as filter * x, and compute x * x.
41
+
42
+ RTC_DCHECK_GT (x_size, x_start_index) ;
43
+ const float * x_p = &x[ x_start_index] ;
44
+ const float * h_p = &h[ 0 ] ;
41
45
46
+ // Initialize values for the accumulation.
42
47
__m128 s_128 = _mm_set1_ps (0 );
43
48
__m128 x2_sum_128 = _mm_set1_ps (0 );
49
+ float x2_sum = 0 .f ;
50
+ float s = 0 ;
44
51
45
- size_t k = 0 ;
46
- if (h.size () > (x.size () - x_index)) {
47
- const size_t limit = x.size () - x_index;
48
- for (; (k + 3 ) < limit; k += 4 , x_index += 4 ) {
49
- const __m128 x_k = _mm_loadu_ps (&x[x_index]);
50
- const __m128 h_k = _mm_loadu_ps (&h[k]);
52
+ // Compute loop chunk sizes until, and after, the wraparound of the circular
53
+ // buffer for x.
54
+ const int chunk1 =
55
+ std::min (h_size, static_cast <int >(x_size - x_start_index));
56
+
57
+ // Perform the loop in two chunks.
58
+ const int chunk2 = h_size - chunk1;
59
+ for (int limit : {chunk1, chunk2}) {
60
+ // Perform 128 bit vector operations.
61
+ const int limit_by_4 = limit >> 2 ;
62
+ for (int k = limit_by_4; k > 0 ; --k, h_p += 4 , x_p += 4 ) {
63
+ // Load the data into 128 bit vectors.
64
+ const __m128 x_k = _mm_loadu_ps (x_p);
65
+ const __m128 h_k = _mm_loadu_ps (h_p);
51
66
const __m128 xx = _mm_mul_ps (x_k, x_k);
67
+ // Compute and accumulate x * x and h * x.
52
68
x2_sum_128 = _mm_add_ps (x2_sum_128, xx);
53
69
const __m128 hx = _mm_mul_ps (h_k, x_k);
54
70
s_128 = _mm_add_ps (s_128, hx);
55
71
}
56
72
57
- for (; k < limit; ++k, ++x_index) {
58
- x2_sum += x[x_index] * x[x_index];
59
- s += h[k] * x[x_index];
73
+ // Perform non-vector operations for any remaining items.
74
+ for (int k = limit - limit_by_4 * 4 ; k > 0 ; --k, ++h_p, ++x_p) {
75
+ const float x_k = *x_p;
76
+ x2_sum += x_k * x_k;
77
+ s += *h_p * x_k;
60
78
}
61
- x_index = 0 ;
62
- }
63
79
64
- for (; k + 3 < h.size (); k += 4 , x_index += 4 ) {
65
- const __m128 x_k = _mm_loadu_ps (&x[x_index]);
66
- const __m128 h_k = _mm_loadu_ps (&h[k]);
67
- const __m128 xx = _mm_mul_ps (x_k, x_k);
68
- x2_sum_128 = _mm_add_ps (x2_sum_128, xx);
69
- const __m128 hx = _mm_mul_ps (h_k, x_k);
70
- s_128 = _mm_add_ps (s_128, hx);
71
- }
72
-
73
- for (; k < h.size (); ++k, ++x_index) {
74
- x2_sum += x[x_index] * x[x_index];
75
- s += h[k] * x[x_index];
80
+ x_p = &x[0 ];
76
81
}
77
82
83
+ // Combine the accumulated vector and scalar values.
78
84
float * v = reinterpret_cast <float *>(&x2_sum_128);
79
85
x2_sum += v[0 ] + v[1 ] + v[2 ] + v[3 ];
80
86
v = reinterpret_cast <float *>(&s_128);
81
87
s += v[0 ] + v[1 ] + v[2 ] + v[3 ];
82
88
83
89
// Compute the matched filter error.
84
90
const float e = std::min (32767 .f , std::max (-32768 .f , y[i] - s));
85
- ( *error_sum) += e * e;
91
+ *error_sum += e * e;
86
92
87
93
// Update the matched filter estimate in an NLMS manner.
88
94
if (x2_sum > x2_sum_threshold) {
89
95
RTC_DCHECK_LT (0 .f , x2_sum);
90
96
const float alpha = 0 .7f * e / x2_sum;
97
+ const __m128 alpha_128 = _mm_set1_ps (alpha);
91
98
92
99
// filter = filter + 0.7 * (y - filter * x) / x * x.
93
- size_t x_index = x_start_index;
94
- for (size_t k = 0 ; k < h.size (); ++k) {
95
- h[k] += alpha * x[x_index];
96
- x_index = x_index < (x.size () - 1 ) ? x_index + 1 : 0 ;
100
+ float * h_p = &h[0 ];
101
+ x_p = &x[x_start_index];
102
+
103
+ // Perform the loop in two chunks.
104
+ for (int limit : {chunk1, chunk2}) {
105
+ // Perform 128 bit vector operations.
106
+ const int limit_by_4 = limit >> 2 ;
107
+ for (int k = limit_by_4; k > 0 ; --k, h_p += 4 , x_p += 4 ) {
108
+ // Load the data into 128 bit vectors.
109
+ __m128 h_k = _mm_loadu_ps (h_p);
110
+ const __m128 x_k = _mm_loadu_ps (x_p);
111
+
112
+ // Compute h = h + alpha * x.
113
+ const __m128 alpha_x = _mm_mul_ps (alpha_128, x_k);
114
+ h_k = _mm_add_ps (h_k, alpha_x);
115
+
116
+ // Store the result.
117
+ _mm_storeu_ps (h_p, h_k);
118
+ }
119
+
120
+ // Perform non-vector operations for any remaining items.
121
+ for (int k = limit - limit_by_4 * 4 ; k > 0 ; --k, ++h_p, ++x_p) {
122
+ *h_p += alpha * *x_p;
123
+ }
124
+
125
+ x_p = &x[0 ];
97
126
}
127
+
98
128
*filters_updated = true ;
99
129
}
100
130
101
- x_start_index = x_start_index > 0 ? x_start_index - 1 : x. size () - 1 ;
131
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1 ;
102
132
}
103
133
}
104
134
#endif
@@ -112,7 +142,7 @@ void MatchedFilterCore(size_t x_start_index,
112
142
float * error_sum) {
113
143
// Process for all samples in the sub-block.
114
144
for (size_t i = 0 ; i < kSubBlockSize ; ++i) {
115
- // Apply the matched filter as filter * x. and compute x * x.
145
+ // Apply the matched filter as filter * x, and compute x * x.
116
146
float x2_sum = 0 .f ;
117
147
float s = 0 ;
118
148
size_t x_index = x_start_index;
0 commit comments