Skip to content

Commit ac171b5

Browse files
authored
Merge pull request #60 from JuliaMath/an/precomputew
Precompute w to avoid recomputing it over and over
2 parents a2e78ed + ece8dde commit ac171b5

File tree

3 files changed

+48
-51
lines changed

3 files changed

+48
-51
lines changed

src/algos.jl

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ fft!(::AbstractVector{T}, ::AbstractVector{T}, ::Int, ::Int, ::Direction, ::Abst
44
Int(d)
55
end
66

7+
@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w))
8+
79
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, v::Direction, t::FFTEnum, idx::Int) where {T,U}
810
fft!(out, in, start_out, start_in, v, t, g, idx)
911
end
@@ -52,7 +54,7 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
5254
s_in = root.s_in
5355
s_out = root.s_out
5456

55-
w1 = convert(T, cispi(direction_sign(d)*2/N))
57+
w1 = _conj(root.w, d)
5658
wj1 = one(T)
5759
tmp = g.workspace[idx]
5860
@inbounds for j1 in 0:N1-1
@@ -82,17 +84,17 @@ Discrete Fourier Transform, O(N^2) algorithm, in place.
8284
`stride_out`: Stride of the output vector
8385
`start_in`: Index of the first element of the input vector
8486
`stride_in`: Stride of the input vector
85-
`d`: Direction of the transform
87+
`w`: The value `cispi(direction_sign(d) * 2 / N)`
8688
8789
"""
88-
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T}
90+
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T}
8991
tmp = in[start_in]
9092
@inbounds for j in 1:N-1
9193
tmp += in[start_in + j*stride_in]
9294
end
9395
out[start_out] = tmp
9496

95-
wk = wkn = w = convert(T, cispi(direction_sign(d)*2/N))
97+
wk = wkn = w
9698
@inbounds for d in 1:N-1
9799
tmp = in[start_in]
98100
@inbounds for k in 1:N-1
@@ -105,7 +107,7 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_o
105107
end
106108
end
107109

108-
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T<:Real}
110+
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::Complex{T}) where {T<:Real}
109111
halfN = N÷2
110112

111113
tmp = Complex{T}(in[start_in])
@@ -114,7 +116,7 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int
114116
end
115117
out[start_out] = tmp
116118

117-
wk = wkn = w = convert(Complex{T}, cispi(direction_sign(d)*2/N))
119+
wk = wkn = w
118120
@inbounds for d in 1:halfN
119121
tmp = Complex{T}(in[start_in])
120122
@inbounds for k in 1:N-1
@@ -129,7 +131,7 @@ end
129131

130132
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::DFT, g::CallGraph{T}, idx::Int) where {T,U}
131133
root = g[idx]
132-
fft_dft!(out, in, root.sz, start_out, root.s_out, start_in, root.s_in, d)
134+
fft_dft!(out, in, root.sz, start_out, root.s_out, start_in, root.s_in, _conj(root.w, d))
133135
end
134136

135137
"""
@@ -144,29 +146,28 @@ Power of 2 FFT, in place
144146
`stride_out`: Stride of the output vector
145147
`start_in`: Index of the first element of the input vector
146148
`stride_in`: Stride of the input vector
147-
`d`: Direction of the transform
149+
`w`: The value `cispi(direction_sign(d) * 2 / N)`
148150
149151
"""
150-
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T, U}
152+
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
151153
if N == 2
152154
out[start_out] = in[start_in] + in[start_in + stride_in]
153155
out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
154156
return
155157
end
156158
m = N ÷ 2
157159

158-
fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, d)
159-
fft_pow2!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*2, d)
160+
fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, w*w)
161+
fft_pow2!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*2, w*w)
160162

161-
w1 = convert(T, cispi(direction_sign(d)*2/N))
162163
wj = one(T)
163164
@inbounds for j in 0:m-1
164165
j1_out = start_out + j*stride_out
165166
j2_out = start_out + (j+m)*stride_out
166167
out_j = out[j1_out]
167168
out[j1_out] = out_j + wj*out[j2_out]
168169
out[j2_out] = out_j - wj*out[j2_out]
169-
wj *= w1
170+
wj *= w
170171
end
171172
end
172173

@@ -175,7 +176,7 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
175176
N = root.sz
176177
s_in = root.s_in
177178
s_out = root.s_out
178-
fft_pow2!(out, in, N, start_out, s_out, start_in, s_in, d)
179+
fft_pow2!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d))
179180
end
180181

181182
"""
@@ -190,13 +191,12 @@ Power of 4 FFT, in place
190191
`stride_out`: Stride of the output vector
191192
`start_in`: Index of the first element of the input vector
192193
`stride_in`: Stride of the input vector
193-
`d`: Direction of the transform
194+
`w`: The value `cispi(direction_sign(d) * 2 / N)`
194195
195196
"""
196-
function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T, U}
197-
ds = direction_sign(d)
198-
plusi = ds*1im
199-
minusi = ds*-1im
197+
function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
198+
plusi = sign(imag(w))*im
199+
minusi = -sign(imag(w))*im
200200
if N == 4
201201
out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in] + in[start_in + 3*stride_in]
202202
out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]*plusi - in[start_in + 2*stride_in] + in[start_in + 3*stride_in]*minusi
@@ -206,17 +206,14 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
206206
end
207207
m = N ÷ 4
208208

209-
@muladd fft_pow4!(out, in, m, start_out , stride_out, start_in , stride_in*4, d)
210-
@muladd fft_pow4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, d)
211-
@muladd fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, d)
212-
@muladd fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, d)
209+
@muladd fft_pow4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w^4)
210+
@muladd fft_pow4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w^4)
211+
@muladd fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w^4)
212+
@muladd fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w^4)
213213

214-
w1 = convert(T, cispi(direction_sign(d)*2/N))
215-
wj = one(T)
216-
217-
w1 = convert(T, cispi(ds*2/N))
218-
w2 = convert(T, cispi(ds*4/N))
219-
w3 = convert(T, cispi(ds*6/N))
214+
w1 = w
215+
w2 = w*w1
216+
w3 = w*w2
220217
wk1 = wk2 = wk3 = one(T)
221218

222219
@inbounds for k in 0:m-1
@@ -240,7 +237,7 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
240237
N = root.sz
241238
s_in = root.s_in
242239
s_out = root.s_out
243-
fft_pow4!(out, in, N, start_out, s_out, start_in, s_in, d)
240+
fft_pow4!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d))
244241
end
245242

246243
"""
@@ -255,12 +252,12 @@ start_out: Index of the first element of the output vector
255252
stride_out: Stride of the output vector
256253
start_in: Index of the first element of the input vector
257254
stride_in: Stride of the input vector
258-
d: Direction of the transform
255+
w: The value `cispi(direction_sign(d) * 2 / N)`
259256
plus120: Depending on direction, perform either ±120° rotation
260257
minus120: Depending on direction, perform either ∓120° rotation
261258
262259
"""
263-
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction, plus120::T, minus120::T) where {T, U}
260+
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, plus120::T, minus120::T) where {T, U}
264261
if N == 3
265262
@muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in]
266263
@muladd out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]*plus120 + in[start_in + 2*stride_in]*minus120
@@ -271,15 +268,13 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
271268
# Size of subproblem
272269
Nprime = N ÷ 3
273270

274-
ds = direction_sign(d)
275-
276271
# Dividing into subproblems
277-
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, d, plus120, minus120)
278-
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, d, plus120, minus120)
279-
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, d, plus120, minus120)
272+
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w^3, plus120, minus120)
273+
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w^3, plus120, minus120)
274+
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w^3, plus120, minus120)
280275

281-
w1 = convert(T, cispi(ds*2/N))
282-
w2 = convert(T, cispi(ds*4/N))
276+
w1 = w
277+
w2 = w*w1
283278
wk1 = wk2 = one(T)
284279
for k in 0:Nprime-1
285280
@muladd k0 = start_out + k*stride_out
@@ -302,8 +297,8 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
302297
p_120 = convert(T, cispi(2/3))
303298
m_120 = convert(T, cispi(4/3))
304299
if d == FFT_FORWARD
305-
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, d, m_120, p_120)
300+
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d), m_120, p_120)
306301
else
307-
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, d, p_120, m_120)
302+
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d), p_120, m_120)
308303
end
309304
end

src/callgraph.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@ Node of a call graph
3232
`sz`: Size of this FFT
3333
3434
"""
35-
struct CallGraphNode
35+
struct CallGraphNode{T}
3636
left::Int
3737
right::Int
3838
type::FFTEnum
3939
sz::Int
4040
s_in::Int
4141
s_out::Int
42+
w::T
4243
end
4344

4445
"""
@@ -51,7 +52,7 @@ Object representing a graph of FFT Calls
5152
5253
"""
5354
struct CallGraph{T<:Complex}
54-
nodes::Vector{CallGraphNode}
55+
nodes::Vector{CallGraphNode{T}}
5556
workspace::Vector{Vector{T}}
5657
end
5758

@@ -83,28 +84,29 @@ Recursively instantiate a set of `CallGraphNode`s
8384
`s_out`: The stride of the output
8485
8586
"""
86-
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
87+
function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
8788
if N == 0
8889
throw(DimensionMismatch("array has to be non-empty"))
8990
end
91+
w = convert(T, cispi(2/N))
9092
if iseven(N)
9193
pow = _ispow24(N)
9294
if !isnothing(pow)
9395
push!(workspace, T[])
94-
push!(nodes, CallGraphNode(0, 0, pow == POW2 ? pow2FFT : pow4FFT, N, s_in, s_out))
96+
push!(nodes, CallGraphNode(0, 0, pow == POW2 ? pow2FFT : pow4FFT, N, s_in, s_out, w))
9597
return 1
9698
end
9799
end
98100
if N % 3 == 0
99101
if nextpow(3, N) == N
100102
push!(workspace, T[])
101-
push!(nodes, CallGraphNode(0, 0, pow3FFT, N, s_in, s_out))
103+
push!(nodes, CallGraphNode(0, 0, pow3FFT, N, s_in, s_out, w))
102104
return 1
103105
end
104106
end
105107
if N == 1 || isprime(N)
106108
push!(workspace, T[])
107-
push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out))
109+
push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out, w))
108110
return 1
109111
end
110112
Ns = [first(x) for x in collect(factor(N)) for _ in 1:last(x)]
@@ -121,12 +123,12 @@ function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{
121123
N1 = N_cp[N1_idx]
122124
end
123125
N2 = N ÷ N1
124-
push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out))
126+
push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out, w))
125127
sz = length(nodes)
126128
push!(workspace, Vector{T}(undef, N))
127129
left_len = CallGraphNode!(nodes, N1, workspace, N2, N2*s_out)
128130
right_len = CallGraphNode!(nodes, N2, workspace, N1*s_in, 1)
129-
nodes[sz] = CallGraphNode(1, 1 + left_len, compositeFFT, N, s_in, s_out)
131+
nodes[sz] = CallGraphNode(1, 1 + left_len, compositeFFT, N, s_in, s_out, w)
130132
return 1 + left_len + right_len
131133
end
132134

@@ -136,7 +138,7 @@ Instantiate a CallGraph from a number `N`
136138
137139
"""
138140
function CallGraph{T}(N::Int) where {T}
139-
nodes = CallGraphNode[]
141+
nodes = CallGraphNode{T}[]
140142
workspace = Vector{Vector{T}}()
141143
CallGraphNode!(nodes, N, workspace, 1, 1)
142144
CallGraph(nodes, workspace)

test/onedim/real_forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717

1818
@testset "temporarily test real dft separately until used by rfft" begin
1919
y_dft = similar(y)
20-
FFTA.fft_dft!(y_dft, x, n, 1, 1, 1, 1, FFTA.FFT_FORWARD)
20+
FFTA.fft_dft!(y_dft, x, n, 1, 1, 1, 1, cispi(-2/n))
2121
@test y y_dft
2222
end
2323
end

0 commit comments

Comments
 (0)