44
# Pick the twiddle factor matching direction `d`: `w` itself when
# `direction_sign(d) === 1`, otherwise its complex conjugate.
# `ifelse` evaluates both candidates before selecting one.
@inline function _conj(w::Complex, d::Direction)
    keep_as_is = direction_sign(d) === 1
    return ifelse(keep_as_is, w, conj(w))
end
66
"""
    fft!(out, in, start_out, start_in, d, t, g, idx)

Run the kernel selected by the tag `t` for node `idx` of the call graph `g`.

`COMPOSITE_FFT` is forwarded to `fft_composite!` together with the graph.
Every other tag reads its size, strides and twiddle factor from `g[idx]` and
calls the matching leaf kernel (`fft_dft!`, `fft_pow2_radix4!`, `fft_pow3!`
or `fft_bluestein!`).

# Arguments
- `out`: Output vector
- `in`: Input vector
- `start_out`: Index of the first element written in `out`
- `start_in`: Index of the first element read from `in`
- `d`: Direction of the transform
- `t`: Kernel tag selecting the implementation
- `g`: Call graph holding the per-node parameters
- `idx`: Index of the node in `g`

# Throws
- `ArgumentError` when `t` matches none of the handled kernel tags.
"""
function fft!(
    out::AbstractVector{T}, in::AbstractVector{T},
    start_out::Int, start_in::Int,
    d::Direction,
    t::FFTEnum,
    g::CallGraph{T},
    idx::Int
) where T
    # The composite kernel walks the graph itself, so it gets `g` and `idx`.
    if t === COMPOSITE_FFT
        return fft_composite!(out, in, start_out, start_in, d, g, idx)
    end

    # Leaf kernels: gather the per-node parameters once.
    node = g[idx]
    stride_in = node.s_in
    stride_out = node.s_out
    len = node.sz
    twiddle = _conj(node.w, d)

    if t === DFT
        return fft_dft!(out, in, len, start_out, stride_out, start_in, stride_in, twiddle)
    elseif t === POW2RADIX4_FFT
        return fft_pow2_radix4!(
            out, in, len, start_out, stride_out, start_in, stride_in, twiddle
        )
    elseif t === POW3_FFT
        # cispi(2/3) for the forward direction, its conjugate otherwise.
        rot = cispi(T(2) / 3)
        minus120 = d === FFT_FORWARD ? rot : conj(rot)
        return fft_pow3!(
            out, in, len, start_out, stride_out, start_in, stride_in, twiddle, minus120
        )
    elseif t === BLUESTEIN
        # Bluestein takes the direction instead of a precomputed twiddle.
        return fft_bluestein!(out, in, d, len, start_out, stride_out, start_in, stride_in)
    end

    throw(ArgumentError(" kernel not implemented"))
end
@@ -49,27 +57,58 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out
4957 right_idx = idx + root. right
5058 left = g[left_idx]
5159 right = g[right_idx]
52- N = root. sz
60+ # N = root.sz
5361 N1 = left. sz
5462 N2 = right. sz
5563 s_in = root. s_in
5664 s_out = root. s_out
5765
66+ Rt = right. type
67+ Lt = left. type
68+
5869 w1 = _conj (root. w, d)
5970 wj1 = one (T)
6071 tmp = g. workspace[idx]
61- @inbounds for j1 in 0 : N1- 1
72+
73+ if Rt === BLUESTEIN
74+ R_scratch = prealloc_blue (N2, d, T)
75+ end
76+ for j1 in 0 : N1- 1
6277 wk2 = wj1
63- fft! (tmp, in, N2* j1+ 1 , start_in + j1* s_in, d, right. type, g, right_idx)
64- j1 > 0 && @inbounds for k2 in 1 : N2- 1
65- tmp[N2* j1 + k2 + 1 ] *= wk2
66- wk2 *= wj1
78+ R_start_in = start_in + j1 * s_in
79+ R_start_out = 1 + N2 * j1
80+
81+ if Rt === BLUESTEIN
82+ R_s_in = right. s_in
83+ R_s_out = right. s_out
84+ fft_bluestein! (tmp, in, d, N2, R_start_out, R_s_out, R_start_in, R_s_in, R_scratch)
85+ else
86+ fft! (tmp, in, R_start_out, R_start_in, d, Rt, g, right_idx)
6787 end
88+
89+ if j1 > 0
90+ @inbounds for k2 in 1 : N2- 1
91+ tmp[R_start_out + k2] *= wk2
92+ wk2 *= wj1
93+ end
94+ end
95+
6896 wj1 *= w1
6997 end
7098
71- @inbounds for k2 in 0 : N2- 1
72- fft! (out, tmp, start_out + k2* s_out, k2+ 1 , d, left. type, g, left_idx)
99+ if Lt === BLUESTEIN
100+ L_scratch = prealloc_blue (N1, d, T)
101+ end
102+ for k2 in 0 : N2- 1
103+ L_start_out = start_out + k2 * s_out
104+ L_start_in = 1 + k2
105+ if Lt === BLUESTEIN
106+ L_s_in = left. s_in
107+ L_s_out = left. s_out
108+ fft_bluestein! (out, tmp, d, N1, L_start_out, L_s_out, L_start_in, L_s_in, L_scratch)
109+ else
110+ fft! (out, tmp, L_start_out, L_start_in, d, Lt, g, left_idx)
111+ end
73112 end
74113end
75114
@@ -178,7 +217,7 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
178217 w1 = w
179218 w2 = w * w1
180219 w3 = w * w2
181- w4 = w * w3
220+ w4 = w2 * w2
182221
183222 fft_pow2_radix4! (out, in, m, start_out , stride_out, start_in , stride_in* 4 , w4)
184223 fft_pow2_radix4! (out, in, m, start_out + m* stride_out, stride_out, start_in + stride_in, stride_in* 4 , w4)
@@ -228,7 +267,8 @@ Power of 3 FFT, in place
228267- `minus120`: Depending on direction, perform either ∓120° rotation
229268
230269"""
231- function fft_pow3! (out:: AbstractVector{T} , in:: AbstractVector{U} , N:: Int , start_out:: Int , stride_out:: Int , start_in:: Int , stride_in:: Int , w:: T , plus120:: T , minus120:: T ) where {T, U}
270+ function fft_pow3! (out:: AbstractVector{T} , in:: AbstractVector{U} , N:: Int , start_out:: Int , stride_out:: Int , start_in:: Int , stride_in:: Int , w:: T , minus120:: T ) where {T, U}
271+ plus120 = conj (minus120)
232272 if N == 3
233273 @muladd out[start_out + 0 ] = in[start_in] + in[start_in + stride_in] + in[start_in + 2 * stride_in]
234274 @muladd out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]* plus120 + in[start_in + 2 * stride_in]* minus120
@@ -240,17 +280,17 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
240280 Nprime = N ÷ 3
241281
242282 # Dividing into subproblems
243- fft_pow3! (out, in, Nprime, start_out, stride_out, start_in, stride_in* 3 , w^ 3 , plus120 , minus120)
244- fft_pow3! (out, in, Nprime, start_out + Nprime* stride_out, stride_out, start_in + stride_in, stride_in* 3 , w^ 3 , plus120 , minus120)
245- fft_pow3! (out, in, Nprime, start_out + 2 * Nprime* stride_out, stride_out, start_in + 2 * stride_in, stride_in* 3 , w^ 3 , plus120, minus120)
283+ fft_pow3! (out, in, Nprime, start_out, stride_out, start_in, stride_in* 3 , w^ 3 , minus120)
284+ fft_pow3! (out, in, Nprime, start_out + Nprime* stride_out, stride_out, start_in + stride_in, stride_in* 3 , w^ 3 , minus120)
285+ fft_pow3! (out, in, Nprime, start_out + 2 * Nprime* stride_out, stride_out, start_in + 2 * stride_in, stride_in* 3 , w^ 3 , minus120)
246286
247287 w1 = w
248288 w2 = w * w1
249289 wk1 = wk2 = one (T)
250290 for k in 0 : Nprime- 1
251- @muladd k0 = start_out + stride_out * k
252- @muladd k1 = start_out + stride_out * (k + Nprime)
253- @muladd k2 = start_out + stride_out * (k + 2 * Nprime)
291+ k0 = start_out + stride_out * k
292+ k1 = start_out + stride_out * (k + Nprime)
293+ k2 = start_out + stride_out * (k + 2 * Nprime)
254294 y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2]
255295 @muladd out[k0] = y_k0 + y_k1 * wk1 + y_k2 * wk2
256296 @muladd out[k1] = y_k0 + y_k1 * wk1 * plus120 + y_k2 * wk2 * minus120
@@ -259,3 +299,87 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
259299 wk2 *= w2
260300 end
261301end
302+
303+
"""
    prealloc_blue(N, d, T)

Allocate the scratch buffers consumed by `fft_bluestein!` for a transform of
size `N` in direction `d` with element type `T`, and fill in the chirp
sequence.

Returns the tuple `(tmp, a_series, b_series, pad_len)` where

- `pad_len` is `nextpow(2, 2N - 1)`;
- `tmp` and `a_series` are uninitialised vectors of length `pad_len`;
- `b_series` has length `pad_len` and holds `cispi(sgn * n^2 / N)` for
  `n = 0, …, N-1` in entries `1:N` (with `sgn = -direction_sign(d)`), the
  mirror image of entries `2:N` in its last `N-1` entries, and zeros in
  between.
"""
function prealloc_blue(N::Int, d::Direction, ::Type{T}) where T<:Number
    pad_len = nextpow(2, 2N - 1)

    b_series = Vector{T}(undef, pad_len)
    a_series = Vector{T}(undef, pad_len)
    tmp = Vector{T}(undef, pad_len)

    b_series[N+1:end] .= zero(T)

    sgn = -direction_sign(d)
    # `p` tracks (i-1)^2 reduced modulo 2N into (-N, N]; the reduction is valid
    # because `cispi` has period 2 in its argument.
    p = 0
    for i in 1:N
        b_series[i] = cispi(sgn * p / N)
        p += (2i - 1) # i^2 - (i-1)^2; prevents overflow unless N is absolutely massive
        p > N && (p -= 2N)
    end

    # Enforce periodic boundaries for b_n: b_{-n} = b_n for n = 1, …, N-1.
    # Only N-1 entries are mirrored. Running one step further would read
    # b_series[N+1], which is outside the chirp (and out of bounds when
    # pad_len == 1, i.e. N == 1).
    for j in 0:N-2
        b_series[pad_len-j] = b_series[2+j]
    end

    return (tmp, a_series, b_series, pad_len)
end
328+
"""
$(TYPEDSIGNATURES)
Bluestein's algorithm, still O(N * log(N)) for large primes,
but with a big constant factor.
Zero-pads two sequences derived from the DFT formula to a
power of 2 length greater than `2N-1` and computes their convolution
with a power 2 FFT.

# Arguments
- `out`: Output vector
- `in`: Input vector
- `d`: Direction of the transform
- `N`: Size of the transform
- `start_out`: Index of the first element of the output vector
- `stride_out`: Stride of the output vector
- `start_in`: Index of the first element of the input vector
- `stride_in`: Stride of the input vector
- `scratch`: Work buffers `(tmp, a_series, b_series, pad_len)` as returned by
  `prealloc_blue(N, d, T)`. Built on every call when omitted. `tmp` and
  `a_series` are overwritten; `b_series` is only read.

"""
function fft_bluestein!(
    out::AbstractVector{T}, in::AbstractVector{T},
    d::Direction,
    N::Int,
    start_out::Int, stride_out::Int,
    start_in::Int, stride_in::Int,
    scratch::Tuple{Vector{T},Vector{T},Vector{T},Int} = prealloc_blue(N, d, T)
) where T<:Number

    work, modulated, chirp, padded_len = scratch

    # Clear everything past the first N entries of both work buffers.
    fill!(view(modulated, N+1:lastindex(modulated)), zero(T))
    fill!(view(work, N+1:lastindex(work)), zero(T))

    # a_n = x_n * conj(b_n), reading the input with the requested stride.
    src = start_in
    for n in 1:N
        modulated[n] = in[src] * conj(chirp[n])
        src += stride_in
    end

    twiddle = cispi(T(2) / padded_len)
    # The chirp vector is left untouched because it is needed again below.
    fft_pow2_radix4!(work, modulated, padded_len, 1, 1, 1, 1, twiddle)   # Fa
    fft_pow2_radix4!(modulated, chirp, padded_len, 1, 1, 1, 1, twiddle)  # Fb

    # Convolution theorem: multiply the spectra pointwise, then transform
    # back with the conjugated twiddle.
    work .*= modulated
    fft_pow2_radix4!(modulated, work, padded_len, 1, 1, 1, 1, conj(twiddle))
    convolved = modulated

    # X_k = conj(b_k) * (a ⋆ b)_k / pad_len, stored in the now-free `work`.
    spectrum = work
    for k in 1:N
        spectrum[k] = conj(chirp[k]) * convolved[k] / padded_len
    end

    # Scatter the first N results into `out` with the requested stride.
    dest = range(start_out; step=stride_out, length=N)
    copyto!(out, CartesianIndices((dest,)), spectrum, CartesianIndices((N,)))
    return nothing
end
0 commit comments