Skip to content

Commit bf2c0eb

Browse files
authored
Merge pull request #107 from wheeheee/bluestein
Implement Bluestein's algorithm for large primes (>100)
2 parents da36b22 + 473680b commit bf2c0eb

File tree

9 files changed

+223
-91
lines changed

9 files changed

+223
-91
lines changed

src/algos.jl

Lines changed: 155 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,34 @@ end
44

55
@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w))
66

7-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, start_out::Int, start_in::Int, d::Direction, t::FFTEnum, g::CallGraph{T}, idx::Int) where T
7+
function fft!(
8+
out::AbstractVector{T}, in::AbstractVector{T},
9+
start_out::Int, start_in::Int,
10+
d::Direction,
11+
t::FFTEnum,
12+
g::CallGraph{T},
13+
idx::Int
14+
) where T
815
if t === COMPOSITE_FFT
916
fft_composite!(out, in, start_out, start_in, d, g, idx)
1017
else
1118
root = g[idx]
12-
if t == DFT
13-
fft_dft!(out, in, root.sz, start_out, root.s_out, start_in, root.s_in, _conj(root.w, d))
19+
s_in = root.s_in
20+
s_out = root.s_out
21+
N = root.sz
22+
w = _conj(root.w, d)
23+
if t === DFT
24+
fft_dft!(out, in, N, start_out, s_out, start_in, s_in, w)
25+
elseif t === POW2RADIX4_FFT
26+
fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, w)
27+
elseif t === POW3_FFT
28+
_m_120 = cispi(T(2) / 3)
29+
m_120 = d === FFT_FORWARD ? _m_120 : conj(_m_120)
30+
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, w, m_120)
31+
elseif t === BLUESTEIN
32+
fft_bluestein!(out, in, d, N, start_out, s_out, start_in, s_in)
1433
else
15-
s_in = root.s_in
16-
s_out = root.s_out
17-
if t === POW2RADIX4_FFT
18-
fft_pow2_radix4!(out, in, root.sz, start_out, s_out, start_in, s_in, _conj(root.w, d))
19-
elseif t === POW3_FFT
20-
p_120 = cispi(T(2)/3)
21-
m_120 = cispi(T(4)/3)
22-
_p_120, _m_120 = d == FFT_FORWARD ? (p_120, m_120) : (m_120, p_120)
23-
fft_pow3!(out, in, root.sz, start_out, s_out, start_in, s_in, _conj(root.w, d), _m_120, _p_120)
24-
else
25-
throw(ArgumentError("kernel not implemented"))
26-
end
34+
throw(ArgumentError("kernel not implemented"))
2735
end
2836
end
2937
end
@@ -49,27 +57,58 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out
4957
right_idx = idx + root.right
5058
left = g[left_idx]
5159
right = g[right_idx]
52-
N = root.sz
60+
# N = root.sz
5361
N1 = left.sz
5462
N2 = right.sz
5563
s_in = root.s_in
5664
s_out = root.s_out
5765

66+
Rt = right.type
67+
Lt = left.type
68+
5869
w1 = _conj(root.w, d)
5970
wj1 = one(T)
6071
tmp = g.workspace[idx]
61-
@inbounds for j1 in 0:N1-1
72+
73+
if Rt === BLUESTEIN
74+
R_scratch = prealloc_blue(N2, d, T)
75+
end
76+
for j1 in 0:N1-1
6277
wk2 = wj1
63-
fft!(tmp, in, N2*j1+1, start_in + j1*s_in, d, right.type, g, right_idx)
64-
j1 > 0 && @inbounds for k2 in 1:N2-1
65-
tmp[N2*j1 + k2 + 1] *= wk2
66-
wk2 *= wj1
78+
R_start_in = start_in + j1 * s_in
79+
R_start_out = 1 + N2 * j1
80+
81+
if Rt === BLUESTEIN
82+
R_s_in = right.s_in
83+
R_s_out = right.s_out
84+
fft_bluestein!(tmp, in, d, N2, R_start_out, R_s_out, R_start_in, R_s_in, R_scratch)
85+
else
86+
fft!(tmp, in, R_start_out, R_start_in, d, Rt, g, right_idx)
6787
end
88+
89+
if j1 > 0
90+
@inbounds for k2 in 1:N2-1
91+
tmp[R_start_out + k2] *= wk2
92+
wk2 *= wj1
93+
end
94+
end
95+
6896
wj1 *= w1
6997
end
7098

71-
@inbounds for k2 in 0:N2-1
72-
fft!(out, tmp, start_out + k2*s_out, k2+1, d, left.type, g, left_idx)
99+
if Lt === BLUESTEIN
100+
L_scratch = prealloc_blue(N1, d, T)
101+
end
102+
for k2 in 0:N2-1
103+
L_start_out = start_out + k2 * s_out
104+
L_start_in = 1 + k2
105+
if Lt === BLUESTEIN
106+
L_s_in = left.s_in
107+
L_s_out = left.s_out
108+
fft_bluestein!(out, tmp, d, N1, L_start_out, L_s_out, L_start_in, L_s_in, L_scratch)
109+
else
110+
fft!(out, tmp, L_start_out, L_start_in, d, Lt, g, left_idx)
111+
end
73112
end
74113
end
75114

@@ -178,7 +217,7 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
178217
w1 = w
179218
w2 = w * w1
180219
w3 = w * w2
181-
w4 = w * w3
220+
w4 = w2 * w2
182221

183222
fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w4)
184223
fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w4)
@@ -228,7 +267,8 @@ Power of 3 FFT, in place
228267
- `minus120`: Depending on direction, perform either ∓120° rotation
229268
230269
"""
231-
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, plus120::T, minus120::T) where {T, U}
270+
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, minus120::T) where {T, U}
271+
plus120 = conj(minus120)
232272
if N == 3
233273
@muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in]
234274
@muladd out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]*plus120 + in[start_in + 2*stride_in]*minus120
@@ -240,17 +280,17 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
240280
Nprime = N ÷ 3
241281

242282
# Dividing into subproblems
243-
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w^3, plus120, minus120)
244-
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w^3, plus120, minus120)
245-
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w^3, plus120, minus120)
283+
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w^3, minus120)
284+
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w^3, minus120)
285+
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w^3, minus120)
246286

247287
w1 = w
248288
w2 = w * w1
249289
wk1 = wk2 = one(T)
250290
for k in 0:Nprime-1
251-
@muladd k0 = start_out + stride_out * k
252-
@muladd k1 = start_out + stride_out * (k + Nprime)
253-
@muladd k2 = start_out + stride_out * (k + 2 * Nprime)
291+
k0 = start_out + stride_out * k
292+
k1 = start_out + stride_out * (k + Nprime)
293+
k2 = start_out + stride_out * (k + 2 * Nprime)
254294
y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2]
255295
@muladd out[k0] = y_k0 + y_k1 * wk1 + y_k2 * wk2
256296
@muladd out[k1] = y_k0 + y_k1 * wk1 * plus120 + y_k2 * wk2 * minus120
@@ -259,3 +299,87 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
259299
wk2 *= w2
260300
end
261301
end
302+
303+
304+
function prealloc_blue(N::Int, d::Direction, ::Type{T}) where T<:Number
305+
pad_len = nextpow(2, 2N - 1)
306+
307+
b_series = Vector{T}(undef, pad_len)
308+
a_series = Vector{T}(undef, pad_len)
309+
tmp = Vector{T}(undef, pad_len)
310+
311+
b_series[N+1:end] .= zero(T)
312+
313+
sgn = -direction_sign(d)
314+
p = 0 # n^2
315+
for i in 1:N
316+
b_series[i] = cispi(sgn * p / N)
317+
p += (2i - 1) # prevents overflow unless N is absolutely massive
318+
p > N && (p -= 2N)
319+
end
320+
321+
# enforce periodic boundaries for b_n
322+
for j in 0:N-1
323+
b_series[pad_len-j] = b_series[2+j]
324+
end
325+
326+
return (tmp, a_series, b_series, pad_len)
327+
end
328+
329+
"""
330+
$(TYPEDSIGNATURES)
331+
Bluestein's algorithm, still O(N * log(N)) for large primes,
332+
but with a big constant factor.
333+
Zero-pads two sequences derived from the DFT formula to a
334+
power of 2 length greater than `2N-1` and computes their convolution
335+
with a power 2 FFT.
336+
337+
# Arguments
338+
- `out`: Output vector
339+
- `in`: Input vector
340+
- `d`: Direction of the transform
341+
- `N`: Size of the transform
342+
- `start_out`: Index of the first element of the output vector
343+
- `stride_out`: Stride of the output vector
344+
- `start_in`: Index of the first element of the input vector
345+
- `stride_in`: Stride of the input vector
346+
- `w`: The value `cispi(direction_sign(d) * 2 / N)`
347+
348+
"""
349+
function fft_bluestein!(
350+
out::AbstractVector{T}, in::AbstractVector{T},
351+
d::Direction,
352+
N::Int,
353+
start_out::Int, stride_out::Int,
354+
start_in::Int, stride_in::Int,
355+
scratch::Tuple{Vector{T},Vector{T},Vector{T},Int}=prealloc_blue(N, d, T)
356+
) where T<:Number
357+
358+
(tmp, a_series, b_series, pad_len) = scratch
359+
360+
a_series[N+1:end] .= zero(T)
361+
tmp[N+1:end] .= zero(T)
362+
363+
for i in 1:N
364+
a_series[i] = in[start_in+(i-1)*stride_in] * conj(b_series[i])
365+
end
366+
367+
w_pad = cispi(T(2) / pad_len)
368+
# leave b_n vector alone for last step
369+
fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, w_pad) # Fa
370+
fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, w_pad) # Fb
371+
372+
tmp .*= a_series
373+
# convolution theorem ifft
374+
fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, conj(w_pad))
375+
conv_a_b = a_series
376+
377+
Xk = tmp
378+
for i in 1:N
379+
Xk[i] = conj(b_series[i]) * conv_a_b[i] / pad_len
380+
end
381+
382+
out_inds = range(start_out; step=stride_out, length=N)
383+
copyto!(out, CartesianIndices((out_inds,)), Xk, CartesianIndices((N,)))
384+
return nothing
385+
end

src/callgraph.jl

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@enum Direction FFT_FORWARD=-1 FFT_BACKWARD=1
2-
@enum Pow24 POW2=2 POW4=1
3-
@enum FFTEnum COMPOSITE_FFT DFT POW3_FFT POW2RADIX4_FFT
2+
@enum Pow24 POW2 POW4
3+
@enum FFTEnum COMPOSITE_FFT DFT POW3_FFT POW2RADIX4_FFT BLUESTEIN
44

55
"""
66
$(TYPEDEF)
@@ -30,13 +30,18 @@ Object representing a graph of FFT Calls
3030
# Arguments
3131
- `nodes`: Nodes keeping track of the graph
3232
- `workspace`: Preallocated Workspace
33+
- `BLUESTEIN_CUTOFF`: Minimum prime that will be FFTed with the
34+
Bluestein algorithm, below which the O(N^2) DFT is used.
3335
3436
"""
3537
struct CallGraph{T<:Complex}
3638
nodes::Vector{CallGraphNode{T}}
3739
workspace::Vector{Vector{T}}
40+
BLUESTEIN_CUTOFF::Int
3841
end
3942

43+
const DEFAULT_BLUESTEIN_CUTOFF = 73
44+
4045
# Get the node in the graph at index i
4146
Base.getindex(g::CallGraph{T}, i::Int) where {T} = g.nodes[i]
4247

@@ -45,13 +50,13 @@ $(TYPEDSIGNATURES)
4550
Check if `N` is a power of 2 or 4
4651
4752
"""
48-
function _ispow24(N::Int)
49-
if ispow2(N)
50-
zero_cnt = trailing_zeros(N)
51-
return iseven(zero_cnt) ? POW4 : POW2
52-
end
53-
return nothing
54-
end
53+
# function _ispow24(N::Int)
54+
# if ispow2(N)
55+
# zero_cnt = trailing_zeros(N)
56+
# return iseven(zero_cnt) ? POW4 : POW2
57+
# end
58+
# return nothing
59+
# end
5560

5661
"""
5762
$(TYPEDSIGNATURES)
@@ -65,25 +70,30 @@ Recursively instantiate a set of `CallGraphNode`s
6570
- `s_out`: The stride of the output
6671
6772
"""
68-
function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
73+
function CallGraphNode!(
74+
nodes::Vector{CallGraphNode{T}},
75+
N::Int,
76+
workspace::Vector{Vector{T}},
77+
BLUESTEIN_CUTOFF::Int,
78+
s_in::Int, s_out::Int)::Int where {T}
6979
if N <= 0
7080
throw(DimensionMismatch("Array length must be strictly positive"))
7181
end
7282
w = cispi(T(2) / N)
73-
if iseven(N)
74-
pow = _ispow24(N)
75-
if !isnothing(pow)
76-
push!(workspace, T[])
77-
push!(nodes, CallGraphNode(0, 0, POW2RADIX4_FFT, N, s_in, s_out, w))
78-
return 1
79-
end
83+
if iseven(N) && ispow2(N)
84+
# _ispow24(N)
85+
push!(workspace, T[])
86+
push!(nodes, CallGraphNode(0, 0, POW2RADIX4_FFT, N, s_in, s_out, w))
87+
return 1
8088
elseif N % 3 == 0 && nextpow(3, N) == N
8189
push!(workspace, T[])
8290
push!(nodes, CallGraphNode(0, 0, POW3_FFT, N, s_in, s_out, w))
8391
return 1
8492
elseif N == 1 || Primes.isprime(N)
8593
push!(workspace, T[])
86-
push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w))
94+
# use Bluestein's algorithm for big primes
95+
LEAF_ALG = N < BLUESTEIN_CUTOFF ? DFT : BLUESTEIN
96+
push!(nodes, CallGraphNode(0, 0, LEAF_ALG, N, s_in, s_out, w))
8797
return 1
8898
end
8999
fzn = Primes.factor(N)
@@ -98,16 +108,16 @@ function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, workspace::Vect
98108
N_cp = cumprod(Ns) # reverse(Ns) another choice
99109
N1_idx = searchsortedlast(N_cp, N_isqrt)
100110
N1 = N_cp[N1_idx] # N1 <= N_isqrt <= N_fsqrt
101-
if N1_idx != lastindex(N_cp) && (abs(N_cp[N1_idx+1] - N_fsqrt) < (N_fsqrt - N1))
111+
if N1_idx != lastindex(N_cp) && (N_cp[N1_idx+1] - N_fsqrt < (N_fsqrt - N1))
102112
N1 = N_cp[N1_idx+1] # can be >= N_fsqrt
103113
end
104114
end
105115
N2 = N ÷ N1
106116
push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w))
107117
sz = length(nodes)
108118
push!(workspace, Vector{T}(undef, N))
109-
left_len = CallGraphNode!(nodes, N1, workspace, N2 , N2 * s_out)
110-
right_len = CallGraphNode!(nodes, N2, workspace, N1 * s_in, 1)
119+
left_len = CallGraphNode!(nodes, N1, workspace, BLUESTEIN_CUTOFF, N2 , N2 * s_out)
120+
right_len = CallGraphNode!(nodes, N2, workspace, BLUESTEIN_CUTOFF, N1 * s_in, 1)
111121
nodes[sz] = CallGraphNode(1, 1 + left_len, COMPOSITE_FFT, N, s_in, s_out, w)
112122
return 1 + left_len + right_len
113123
end
@@ -117,9 +127,9 @@ $(TYPEDSIGNATURES)
117127
Instantiate a CallGraph from a number `N`
118128
119129
"""
120-
function CallGraph{T}(N::Int) where {T}
130+
function CallGraph{T}(N::Int, BLUESTEIN_CUTOFF::Int) where {T}
121131
nodes = CallGraphNode{T}[]
122132
workspace = Vector{Vector{T}}()
123-
CallGraphNode!(nodes, N, workspace, 1, 1)
124-
CallGraph(nodes, workspace)
133+
CallGraphNode!(nodes, N, workspace, BLUESTEIN_CUTOFF, 1, 1)
134+
CallGraph(nodes, workspace, BLUESTEIN_CUTOFF)
125135
end

0 commit comments

Comments
 (0)