Skip to content

Commit d3a1f8b

Browse files
committed
Polish, accelerate real dispatch
1 parent 819dbfb commit d3a1f8b

File tree

3 files changed

+44
-28
lines changed

3 files changed

+44
-28
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
8+
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
89
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
910

1011
[compat]

src/FFTA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module FFTA
22

3-
using Primes, DocStringExtensions
3+
using Primes, DocStringExtensions, LoopVectorization
44
import Base: getindex
55
export fft, bfft
66

src/algos.jl

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
@enum Direction FFT_FORWARD FFT_BACKWARD
22
abstract type AbstractFFTType end
33

4+
"""
    alternatingSum(x)

Return `x[1] - x[2] + x[3] - x[4] + …`, i.e. the sum of the elements of `x`
with odd-indexed entries added and even-indexed entries subtracted.
"""
function alternatingSum(x::AbstractVector{T}) where T
    acc = x[1]
    @turbo for idx in 2:length(x)
        # +1 for odd idx, -1 for even idx
        sgn = convert(T, 2 * (idx % 2) - 1)
        acc += x[idx] * sgn
    end
    return acc
end
11+
412
# Represents a Composite Cooley-Tukey FFT
513
struct CompositeFFT <: AbstractFFTType end
614

@@ -104,6 +112,13 @@ function fft(x::AbstractVector{T}) where {T}
104112
y
105113
end
106114

115+
"""
    fft(x::AbstractVector{<:Real})

Forward transform of the real vector `x`. Allocates and returns a vector with
element type `Complex{T}` (via `similar`), builds a `CallGraph{Complex{T}}` for
`length(x)`, and runs `fft!` starting from the first node of that graph.
"""
function fft(x::AbstractVector{T}) where {T <: Real}
    C = Complex{T}
    result = similar(x, C)
    graph = CallGraph{C}(length(x))
    fft!(result, x, Val(FFT_FORWARD), graph[1].type, graph, 1)
    return result
end
121+
107122
function fft(x::AbstractMatrix{T}) where {T}
108123
M,N = size(x)
109124
y1 = similar(x)
@@ -277,11 +292,13 @@ function fft_pow2!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val
277292
w1 = convert(Complex{T}, cispi(-2/N))
278293
wj = one(Complex{T})
279294
m = N ÷ 2
280-
for j in 2:m
281-
out[j] = out[j] + wj*out[j+m]
295+
@turbo for j in 2:m
296+
out[j] = out[j] + wj*out[j+m]
282297
wj *= w1
283298
end
284-
out[m+2:end] = conj.(out[m:-1:2])
299+
@turbo for j in 2:m
300+
out[m+j] = conj(out[m-j+2])
301+
end
285302
end
286303

287304
"""
@@ -300,11 +317,11 @@ function fft_pow2!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val
300317
w1 = convert(Complex{T}, cispi(2/N))
301318
wj = one(Complex{T})
302319
m = N ÷ 2
303-
for j in 2:m
304-
out[j] = out[j] + wj*out[j+m]
320+
@turbo for j in 2:m
321+
out[j] = out[j] + wj*out[j+m]
322+
out[m+j] = conj(out[m-i+2])
305323
wj *= w1
306324
end
307-
out[m+2:end] = conj.(out[m:-1:2])
308325
end
309326

310327
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T}
@@ -360,38 +377,36 @@ end
360377
"""
    fft_dft!(out, in, ::Val{FFT_FORWARD})

Forward discrete Fourier transform of the real vector `in` by direct summation,
written into the complex vector `out`.

With `N = length(out)`, bins `1:(N - N÷2)` are computed explicitly (plus the
Nyquist bin `N÷2 + 1` when `N` is even); the remaining bins are filled from the
conjugate symmetry of a real-input DFT, `out[N - d + 2] == conj(out[d])`.

# Arguments
- `out`: destination; its length sets the transform size `N`.
- `in`: real input samples (presumably of the same length as `out` — the
  caller is expected to guarantee this).

Returns `nothing`; the result is stored in `out`.
"""
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{FFT_FORWARD}) where {T<:Real}
    N = length(out)
    halfN = N ÷ 2
    lastbin = N - halfN  # last bin of the lower half that has a mirror image

    # Base twiddle factor exp(-2πi/N).
    w = convert(Complex{T}, cispi(-2/N))

    # DC bin: all twiddles are 1.
    out[1] = sum(in)
    # Nyquist bin (even N only): the twiddles alternate between +1 and -1.
    iseven(N) && (out[halfN+1] = alternatingSum(in))

    wd = w  # holds w^(d-1) for the current output bin d
    for d in 2:lastbin
        # Start from a complex accumulator so its type never changes in the loop.
        acc = Complex{T}(in[1])
        wdk = wd  # holds w^((d-1)*(k-1)) for the current input index k
        for k in 2:N
            acc += wdk * in[k]
            wdk *= wd
        end
        out[d] = acc
        wd *= w
    end

    # Real input => Hermitian spectrum: bin N-d+2 is the conjugate of bin d.
    # Source indices (2:lastbin) and destination indices never overlap.
    for d in 2:lastbin
        out[N-d+2] = conj(out[d])
    end
    return nothing
end
385400

386401
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{FFT_BACKWARD}) where {T<:Real}
387402
N = length(out)
388403
halfN = N÷2
389-
wn² = wn = w = convert(T, cispi(2/N))
404+
wn² = wn = w = convert(Complex{T}, cispi(2/N))
390405
wn_1 = one(T)
391406

392407
out .= in[1]
393408
out[1] = sum(in)
394-
iseven(N) && (out[halfN+1] = foldr(-,in))
409+
iseven(N) && (out[halfN+1] = alternatingSum(in))
395410

396411
wk = wn²;
397412
for d in 2:halfN
@@ -410,10 +425,10 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, ::Val{
410425
end
411426

412427

413-
"""
    fft!(out, in, ::Val{FFT_FORWARD}, ::DFT, ::CallGraph, ::Int)

Forward-transform entry point for a node tagged `DFT`: forwards `out` and `in`
to `fft_dft!`. The call graph and index arguments are not used. The element
type of `in` may differ from that of `out`.
"""
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, direction::Val{FFT_FORWARD}, ::DFT, ::CallGraph{T}, ::Int) where {T,U}
    return fft_dft!(out, in, direction)
end
416431

417-
"""
    fft!(out, in, ::Val{FFT_BACKWARD}, ::DFT, ::CallGraph, ::Int)

Backward-transform entry point for a node tagged `DFT`: forwards `out` and `in`
to `fft_dft!`. The call graph and index arguments are not used. The element
type of `in` may differ from that of `out`.
"""
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, direction::Val{FFT_BACKWARD}, ::DFT, ::CallGraph{T}, ::Int) where {T,U}
    return fft_dft!(out, in, direction)
end

0 commit comments

Comments
 (0)