88namespace stdx = std::experimental;
99namespace cp_algo ::math::fft {
1010 using ftype = double ;
11+ static constexpr size_t bytes = 32 ;
12+ static constexpr size_t flen = bytes / sizeof (ftype);
1113 using point = complex <ftype>;
12- using vftype = stdx::native_simd< ftype> ;
14+ using vftype [[gnu::vector_size(bytes)]] = ftype;
1315 using vpoint = complex <vftype>;
14- static constexpr size_t flen = vftype::size() ;
16+ static constexpr vftype fz = {} ;
1517
1618 struct cvector {
1719 std::vector<vftype> x, y;
@@ -117,15 +119,14 @@ namespace cp_algo::math::fft {
117119 rt = -rt;
118120 }
119121 auto [Bvx, Bvy] = B.vget (k);
120- auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (real (rt), imag (rt));
122+ auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (fz + real (rt), fz + imag (rt));
121123 auto [Ax, Ay] = A.vget (k);
122124 vftype Bx[2 ] = {Brvx, Bvx}, By[2 ] = {Brvy, Bvy};
123- vpoint res = {0 , 0 };
125+ vpoint res = {fz, fz };
124126 for (size_t i = 0 ; i < flen; i++) {
125- vftype Bsx, Bsy;
126- Bsx.copy_from ((ftype*)Bx + flen - i, stdx::element_aligned);
127- Bsy.copy_from ((ftype*)By + flen - i, stdx::element_aligned);
128- res += vpoint (Ax[i], Ay[i]) * vpoint{Bsx, Bsy};
127+ auto Bsx = (vftype*)((ftype*)Bx + flen - i);
128+ auto Bsy = (vftype*)((ftype*)By + flen - i);
129+ res += vpoint (fz + Ax[i], fz + Ay[i]) * vpoint{*Bsx, *Bsy};
129130 }
130131 return res;
131132 }
@@ -144,7 +145,7 @@ namespace cp_algo::math::fft {
144145 if (4 * i <= n) { // radix-4
145146 exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
146147 k *= 4 * i;
147- vpoint v1 = {real (rt), - imag (rt)};
148+ vpoint v1 = {fz + real (rt), fz - imag (rt)};
148149 vpoint v2 = v1 * v1;
149150 vpoint v3 = v1 * v2;
150151 for (size_t j = k; j < k + i; j += flen) {
@@ -154,15 +155,15 @@ namespace cp_algo::math::fft {
154155 auto D = get<vpoint>(j + 3 * i);
155156 set (j , (A + B + C + D));
156157 set (j + 2 * i, (A + B - C - D) * v2);
157- set (j + i, (A - B - vpoint (0 , 1 ) * (C - D)) * v1);
158- set (j + 3 * i, (A - B + vpoint (0 , 1 ) * (C - D)) * v3);
158+ set (j + i, (A - B - vpoint (fz, fz + 1 ) * (C - D)) * v1);
159+ set (j + 3 * i, (A - B + vpoint (fz, fz + 1 ) * (C - D)) * v3);
159160 }
160161 });
161162 i *= 2 ;
162163 } else { // radix-2 fallback
163164 exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
164165 k *= 2 * i;
165- vpoint cvrt = {real (rt), - imag (rt)};
166+ vpoint cvrt = {fz + real (rt), fz - imag (rt)};
166167 for (size_t j = k; j < k + i; j += flen) {
167168 auto A = get<vpoint>(j) + get<vpoint>(j + i);
168169 auto B = get<vpoint>(j) - get<vpoint>(j + i);
@@ -174,7 +175,7 @@ namespace cp_algo::math::fft {
174175 }
175176 checkpoint (" ifft" );
176177 for (size_t k = 0 ; k < n; k += flen) {
177- set (k, get<vpoint>(k) /= (ftype)(n / flen));
178+ set (k, get<vpoint>(k) /= fz + (ftype)(n / flen));
178179 }
179180 }
180181 void fft () {
@@ -184,7 +185,7 @@ namespace cp_algo::math::fft {
184185 i /= 2 ;
185186 exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
186187 k *= 4 * i;
187- vpoint v1 = {real (rt), imag (rt)};
188+ vpoint v1 = {fz + real (rt), fz + imag (rt)};
188189 vpoint v2 = v1 * v1;
189190 vpoint v3 = v1 * v2;
190191 for (size_t j = k; j < k + i; j += flen) {
@@ -194,14 +195,14 @@ namespace cp_algo::math::fft {
194195 auto D = get<vpoint>(j + 3 * i) * v3;
195196 set (j , (A + C) + (B + D));
196197 set (j + i, (A + C) - (B + D));
197- set (j + 2 * i, (A - C) + vpoint (0 , 1 ) * (B - D));
198- set (j + 3 * i, (A - C) - vpoint (0 , 1 ) * (B - D));
198+ set (j + 2 * i, (A - C) + vpoint (fz, fz + 1 ) * (B - D));
199+ set (j + 3 * i, (A - C) - vpoint (fz, fz + 1 ) * (B - D));
199200 }
200201 });
201202 } else { // radix-2 fallback
202203 exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
203204 k *= 2 * i;
204- vpoint vrt = {real (rt), imag (rt)};
205+ vpoint vrt = {fz + real (rt), fz + imag (rt)};
205206 for (size_t j = k; j < k + i; j += flen) {
206207 auto t = get<vpoint>(j + i) * vrt;
207208 set (j + i, get<vpoint>(j) - t);
0 commit comments