Skip to content

Commit c032015

Browse files
authored
Merge pull request #9 from dannys4/ds/dev
Better typing, support for reals
2 parents f00bfea + d3a1f8b commit c032015

File tree

4 files changed

+153
-45
lines changed

4 files changed

+153
-45
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
8+
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
89
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
910

1011
[compat]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# FFTA: Fastest Fourier Transform in my Apartment
22
## A library by Danny Sharp
33

4-
This is a *pure Julia* implementation of FFTs, with the goal that this could supplant other FFTs for applications that require odd Julia objects. Currently this supports `AbstractArray{T,N}` for `T<:Complex` and `N` in `{1,2}` (i.e. `AbstractVector` and `AbstractMatrix`). If you're looking for more performance, checkout `FFTW.jl`.
4+
This is a *pure Julia* implementation of FFTs, with the goal that this could supplant other FFTs for applications that require odd Julia objects. Currently this supports `AbstractArray{T,N}` where `N` in `{1,2}` (i.e. `AbstractVector` and `AbstractMatrix`). If you're looking for more performance, checkout `FFTW.jl`. The only functions that need to be defined with `T` (besides arithmetic) are `convert(T, x::ComplexF64)` and `one(T)`. This means that `T<:Real` probably doesn't work yet (see Path Forward).
55

66
Path Forward:
77
- Dispatch on `Real`

src/FFTA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module FFTA
22

3-
using Primes, DocStringExtensions
3+
using Primes, DocStringExtensions, LoopVectorization
44
import Base: getindex
55
export fft, bfft
66

src/algos.jl

Lines changed: 150 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
@enum Direction FFT_FORWARD FFT_BACKWARD
22
abstract type AbstractFFTType end
33

4+
function alternatingSum(x::AbstractVector{T}) where T
5+
y = x[1]
6+
@turbo for i in 2:length(x)
7+
y += (x[i] * convert(T,(2 * (i % 2) - 1)))
8+
end
9+
y
10+
end
11+
412
# Represents a Composite Cooley-Tukey FFT
513
struct CompositeFFT <: AbstractFFTType end
614

@@ -51,7 +59,7 @@ struct CallGraph{T<:Complex}
5159
end
5260

5361
# Get the node in the graph at index i
54-
Base.getindex(g::CallGraph{T}, i::Int) where {T<:Complex} = g.nodes[i]
62+
Base.getindex(g::CallGraph{T}, i::Int) where {T} = g.nodes[i]
5563

5664
# Get the left child of the node at index `i`
5765
leftNode(g::CallGraph, i::Int) = g[i+g[i].left]
@@ -60,7 +68,7 @@ leftNode(g::CallGraph, i::Int) = g[i+g[i].left]
6068
rightNode(g::CallGraph, i::Int) = g[i+g[i].right]
6169

6270
# Recursively instantiate a set of `CallGraphNode`s
63-
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}})::Int where {T<:Complex}
71+
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}})::Int where {T}
6472
facs = factor(N)
6573
Ns = [first(x) for x in collect(facs) for _ in 1:last(x)]
6674
if length(Ns) == 1 || Ns[end] == 2
@@ -90,21 +98,28 @@ function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{
9098
end
9199

92100
# Instantiate a CallGraph from a number `N`
93-
function CallGraph{T}(N::Int) where {T<:Complex}
101+
function CallGraph{T}(N::Int) where {T}
94102
nodes = CallGraphNode[]
95103
workspace = Vector{Vector{T}}()
96104
CallGraphNode!(nodes, N, workspace)
97105
CallGraph(nodes, workspace)
98106
end
99107

100-
function fft(x::AbstractVector{T}) where {T<:Complex}
108+
function fft(x::AbstractVector{T}) where {T}
101109
y = similar(x)
102110
g = CallGraph{T}(length(x))
103111
fft!(y, x, Val(FFT_FORWARD), g[1].type, g, 1)
104112
y
105113
end
106114

107-
function fft(x::AbstractMatrix{T}) where {T<:Complex}
115+
function fft(x::AbstractVector{T}) where {T <: Real}
116+
y = similar(x, Complex{T})
117+
g = CallGraph{Complex{T}}(length(x))
118+
fft!(y, x, Val(FFT_FORWARD), g[1].type, g, 1)
119+
y
120+
end
121+
122+
function fft(x::AbstractMatrix{T}) where {T}
108123
M,N = size(x)
109124
y1 = similar(x)
110125
y2 = similar(x)
@@ -121,14 +136,14 @@ function fft(x::AbstractMatrix{T}) where {T<:Complex}
121136
y2
122137
end
123138

124-
function bfft(x::AbstractVector{T}) where {T<:Complex}
139+
function bfft(x::AbstractVector{T}) where {T}
125140
y = similar(x)
126141
g = CallGraph{T}(length(x))
127142
fft!(y, x, Val(FFT_BACKWARD), g[1].type, g, 1)
128143
y
129144
end
130145

131-
function bfft(x::AbstractMatrix{T}) where {T<:Complex}
146+
function bfft(x::AbstractMatrix{T}) where {T}
132147
M,N = size(x)
133148
y1 = similar(x)
134149
y2 = similar(x)
@@ -147,24 +162,23 @@ end
147162

148163
fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{<:Direction}, ::AbstractFFTType, ::CallGraph{T}, ::Int) where {T} = nothing
149164

150-
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{T}, v::Val{FFT_FORWARD}, t::AbstractFFTType, idx::Int) where {T}
165+
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{U}, v::Val{FFT_FORWARD}, t::AbstractFFTType, idx::Int) where {T,U}
151166
fft!(out, in, v, t, g, idx)
152167
end
153168

154-
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{T}, v::Val{FFT_BACKWARD}, t::AbstractFFTType, idx::Int) where {T}
169+
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{U}, v::Val{FFT_BACKWARD}, t::AbstractFFTType, idx::Int) where {T,U}
155170
fft!(out, in, v, t, g, idx)
156171
end
157172

158-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}, ::CompositeFFT, g::CallGraph{T}, idx::Int) where {T<:Complex}
173+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, ::Val{FFT_FORWARD}, ::CompositeFFT, g::CallGraph{T}, idx::Int) where {T,U}
159174
N = length(out)
160175
left = leftNode(g,idx)
161176
right = rightNode(g,idx)
162177
N1 = left.sz
163178
N2 = right.sz
164179

165-
inc = 2*π/N
166-
w1 = T(cos(inc), -sin(inc))
167-
wj1 = T(1, 0)
180+
w1 = convert(T, cispi(-2/N))
181+
wj1 = one(T)
168182
tmp = g.workspace[idx]
169183
for j1 in 1:N1
170184
wk2 = wj1;
@@ -181,16 +195,15 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD},
181195
end
182196
end
183197

184-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}, ::CompositeFFT, g::CallGraph{T}, idx::Int) where {T<:Complex}
198+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, ::Val{FFT_BACKWARD}, ::CompositeFFT, g::CallGraph{T}, idx::Int) where {T,U}
185199
N = length(out)
186200
left = left(g,i)
187201
right = right(g,i)
188202
N1 = left.sz
189203
N2 = right.sz
190204

191-
inc = 2*π/N
192-
w1 = T(cos(inc), sin(inc))
193-
wj1 = T(1, 0)
205+
w1 = convert(T, cispi(2/N))
206+
wj1 = one(T)
194207
tmp = g.workspace[idx]
195208
for j1 in 2:N1
196209
Complex<F,L> wk2 = wj1;
@@ -207,19 +220,19 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}
207220
end
208221
end
209222

210-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}, ::Pow2FFT, ::CallGraph{T}, ::Int) where {T<:Complex}
223+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, ::Val{FFT_FORWARD}, ::Pow2FFT, ::CallGraph{T}, ::Int) where {T,U}
211224
fft_pow2!(out, in, Val(FFT_FORWARD))
212225
end
213226

214-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}, ::Pow2FFT, ::CallGraph{T}, ::Int) where {T<:Complex}
227+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, ::Val{FFT_BACKWARD}, ::Pow2FFT, ::CallGraph{T}, ::Int) where {T,U}
215228
fft_pow2!(out, in, Val(FFT_BACKWARD))
216229
end
217230

218231
"""
219232
Power of 2 FFT in place, forward
220233
221234
"""
222-
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T<:Complex}
235+
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T}
223236
N = length(out)
224237
if N == 1
225238
out[1] = in[1]
@@ -228,9 +241,8 @@ function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORW
228241
fft_pow2!(@view(out[1:(end÷2)]), @view(in[1:2:end]), Val(FFT_FORWARD))
229242
fft_pow2!(@view(out[(end÷2+1):end]), @view(in[2:2:end]), Val(FFT_FORWARD))
230243

231-
inc = 2*π/N
232-
w1 = T(cos(inc), -sin(inc));
233-
wj = T(1,0)
244+
w1 = convert(T, cispi(-2/N))
245+
wj = one(T)
234246
m = N ÷ 2
235247
for j in 1:m
236248
out_j = out[j]
@@ -244,7 +256,7 @@ end
244256
Power of 2 FFT in place, backward
245257
246258
"""
247-
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T<:Complex}
259+
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T}
248260
N = length(out)
249261
if N == 1
250262
out[1] = in[1]
@@ -253,9 +265,8 @@ function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACK
253265
fft_pow2!(@view(out[1:(end÷2)]), @view(in[1:2:end]), Val(FFT_BACKWARD))
254266
fft_pow2!(@view(out[(end÷2+1):end]), @view(in[2:2:end]), Val(FFT_BACKWARD))
255267

256-
inc = 2*π/N
257-
w1 = T(cos(inc), sin(inc));
258-
wj = T(1,0)
268+
w1 = convert(T, cispi(2/N))
269+
wj = one(T)
259270
m = N ÷ 2
260271
for j in 1:m
261272
out_j = out[j]
@@ -265,16 +276,63 @@ function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACK
265276
end
266277
end
267278

268-
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T<:Complex}
279+
"""
280+
Power of 2 FFT in place, forward
281+
282+
"""
283+
function fft_pow2!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T<:Real}
269284
N = length(out)
270-
inc = 2*π/N
271-
wn² = wn = w = T(cos(inc), sin(inc))
272-
wn_1 = T(1., 0.)
285+
if N == 1
286+
out[1] = in[1]
287+
return
288+
end
289+
fft_pow2!(@view(out[1:(end÷2)]), @view(in[1:2:end]), Val(FFT_FORWARD))
290+
fft_pow2!(@view(out[(end÷2+1):end]), @view(in[2:2:end]), Val(FFT_FORWARD))
273291

274-
tmp = in[1]
275-
out .= tmp
292+
w1 = convert(Complex{T}, cispi(-2/N))
293+
wj = one(Complex{T})
294+
m = N ÷ 2
295+
@turbo for j in 2:m
296+
out[j] = out[j] + wj*out[j+m]
297+
wj *= w1
298+
end
299+
@turbo for j in 2:m
300+
out[m+j] = conj(out[m-j+2])
301+
end
302+
end
303+
304+
"""
305+
Power of 2 FFT in place, backward
306+
307+
"""
308+
function fft_pow2!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T<:Real}
309+
N = length(out)
310+
if N == 1
311+
out[1] = in[1]
312+
return
313+
end
314+
fft_pow2!(@view(out[1:(end÷2)]), @view(in[1:2:end]), Val(FFT_BACKWARD))
315+
fft_pow2!(@view(out[(end÷2+1):end]), @view(in[2:2:end]), Val(FFT_BACKWARD))
316+
317+
w1 = convert(Complex{T}, cispi(2/N))
318+
wj = one(Complex{T})
319+
m = N ÷ 2
320+
@turbo for j in 2:m
321+
out[j] = out[j] + wj*out[j+m]
322+
out[m+j] = conj(out[m-i+2])
323+
wj *= w1
324+
end
325+
end
326+
327+
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T}
328+
N = length(out)
329+
wn² = wn = w = convert(T, cispi(-2/N))
330+
wn_1 = one(T)
331+
332+
tmp = in[1];
333+
out .= tmp;
276334
tmp = sum(in)
277-
out[1] = tmp
335+
out[1] = tmp;
278336

279337
wk = wn²;
280338
for d in 2:N
@@ -291,16 +349,15 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKW
291349
end
292350
end
293351

294-
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T<:Complex}
352+
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T}
295353
N = length(out)
296-
inc = 2*π/N
297-
wn² = wn = w = T(cos(inc), -sin(inc));
298-
wn_1 = T(1., 0.);
354+
wn² = wn = w = convert(T, cispi(2/N))
355+
wn_1 = one(T)
299356

300-
tmp = in[1];
301-
out .= tmp;
357+
tmp = in[1]
358+
out .= tmp
302359
tmp = sum(in)
303-
out[1] = tmp;
360+
out[1] = tmp
304361

305362
wk = wn²;
306363
for d in 2:N
@@ -317,11 +374,61 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWA
317374
end
318375
end
319376

377+
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T<:Real}
378+
N = length(out)
379+
halfN = N÷2
380+
wk = wkn = w = convert(Complex{T}, cispi(-2/N))
381+
382+
out[2:N] .= in[1]
383+
out[1] = sum(in)
384+
iseven(N) && (out[halfN+1] = alternatingSum(in))
385+
386+
for d in 2:halfN+1
387+
tmp = in[1]
388+
for k in 2:N
389+
tmp += wkn*in[k]
390+
wkn *= wk
391+
end
392+
out[d] = tmp
393+
wk *= w
394+
wkn = wk
395+
end
396+
@turbo for i in 0:halfN-1
397+
out[N-i] = conj(out[halfN-i])
398+
end
399+
end
400+
401+
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T<:Real}
402+
N = length(out)
403+
halfN = N÷2
404+
wn² = wn = w = convert(Complex{T}, cispi(2/N))
405+
wn_1 = one(T)
406+
407+
out .= in[1]
408+
out[1] = sum(in)
409+
iseven(N) && (out[halfN+1] = alternatingSum(in))
410+
411+
wk = wn²;
412+
for d in 2:halfN
413+
out[d] = in[d]*wk + out[d]
414+
for k in (d+1):halfN
415+
wk *= wn
416+
out[d] = in[k]*wk + out[d]
417+
out[k] = in[d]*wk + out[k]
418+
end
419+
wn_1 = wn
420+
wn *= w
421+
wn² *= (wn*wn_1)
422+
wk = wn²
423+
end
424+
out[(N-halfN+2):end] .= conj.(out[halfN:-1:2])
425+
end
426+
320427

321-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}, ::DFT, ::CallGraph{T}, ::Int) where {T<:Complex}
428+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, ::Val{FFT_FORWARD}, ::DFT, ::CallGraph{T}, ::Int) where {T,U}
322429
fft_dft!(out, in, Val(FFT_FORWARD))
323430
end
324431

325-
function fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}, ::DFT, ::CallGraph{T}, ::Int) where {T<:Complex}
432+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, ::Val{FFT_BACKWARD}, ::DFT, ::CallGraph{T}, ::Int) where {T,U}
326433
fft_dft!(out, in, Val(FFT_BACKWARD))
327434
end

0 commit comments

Comments
 (0)