Skip to content

Commit 6e12fb6

Browse files
authored
Test two dimensions, Add radix-4 transform (#19)
* Change direction to enum * Remove turbos * Add radix 4 transform * Two dimensional tests Co-authored-by: Danny Sharp <[email protected]>
1 parent 2734c13 commit 6e12fb6

13 files changed

+254
-74
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.2.0"
66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
88
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
9+
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
910
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112

src/FFTA.jl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module FFTA
22

3-
using Primes, DocStringExtensions, LoopVectorization
4-
import Base: getindex
3+
using Primes, DocStringExtensions, LoopVectorization, MuladdMacro
54
export fft, bfft
65

76
include("callgraph.jl")
@@ -38,6 +37,23 @@ function fft(x::AbstractMatrix{T}) where {T}
3837
y2
3938
end
4039

40+
"""
    fft(x::AbstractMatrix{T}) where {T <: Real}

Apply the `FFT_FORWARD` transform to a real matrix, first along every column and
then along every row of the intermediate result.

Two `Complex{T}` matrices shaped like `x` are allocated: one receives the
column pass, the other receives the row pass and is returned.
"""
function fft(x::AbstractMatrix{T}) where {T <: Real}
    nrows, ncols = size(x)
    col_pass = similar(x, Complex{T})
    row_pass = similar(x, Complex{T})
    # One call graph per dimension: columns have `nrows` entries, rows have `ncols`.
    col_graph = CallGraph{Complex{T}}(nrows)
    row_graph = CallGraph{Complex{T}}(ncols)

    # Pass 1: transform each column of the real input into `col_pass`.
    for j in 1:ncols
        @views fft!(col_pass[:, j], x[:, j], 1, 1, FFT_FORWARD, col_graph[1].type, col_graph, 1)
    end

    # Pass 2: transform each row of the column-transformed data into `row_pass`.
    for i in 1:nrows
        @views fft!(row_pass[i, :], col_pass[i, :], 1, 1, FFT_FORWARD, row_graph[1].type, row_graph, 1)
    end
    return row_pass
end
56+
4157
function bfft(x::AbstractVector{T}) where {T}
4258
y = similar(x)
4359
g = CallGraph{T}(length(x))
@@ -69,4 +85,21 @@ function bfft(x::AbstractMatrix{T}) where {T}
6985
y2
7086
end
7187

88+
"""
    bfft(x::AbstractMatrix{T}) where {T <: Real}

Apply the `FFT_BACKWARD` transform to a real matrix, first along every column and
then along every row of the intermediate result.

Two `Complex{T}` matrices shaped like `x` are allocated: one receives the
column pass, the other receives the row pass and is returned. No scaling is
applied in this function.
"""
function bfft(x::AbstractMatrix{T}) where {T <: Real}
    nrows, ncols = size(x)
    col_pass = similar(x, Complex{T})
    row_pass = similar(x, Complex{T})
    # One call graph per dimension: columns have `nrows` entries, rows have `ncols`.
    col_graph = CallGraph{Complex{T}}(nrows)
    row_graph = CallGraph{Complex{T}}(ncols)

    # Pass 1: transform each column of the real input into `col_pass`.
    for j in 1:ncols
        @views fft!(col_pass[:, j], x[:, j], 1, 1, FFT_BACKWARD, col_graph[1].type, col_graph, 1)
    end

    # Pass 2: transform each row of the column-transformed data into `row_pass`.
    for i in 1:nrows
        @views fft!(row_pass[i, :], col_pass[i, :], 1, 1, FFT_BACKWARD, row_graph[1].type, row_graph, 1)
    end
    return row_pass
end
104+
72105
end

src/algos.jl

Lines changed: 93 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
function alternatingSum(x::AbstractVector{T}) where T
2-
y = x[1]
3-
@turbo for i in 2:length(x)
4-
y += (x[i] * convert(T,(2 * (i % 2) - 1)))
5-
end
6-
y
7-
end
8-
91
# Generic fallback: for any `AbstractFFTType` this method ignores all of its
# arguments and returns `nothing`.
fft!(::AbstractVector{T}, ::AbstractVector{T}, ::Int, ::Int, ::Direction, ::AbstractFFTType, ::CallGraph{T}, ::Int) where {T} = nothing
102

113
@inline function direction_sign(d::Direction)
@@ -31,56 +23,21 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
3123
w1 = convert(T, cispi(direction_sign(d)*2/N))
3224
wj1 = one(T)
3325
tmp = g.workspace[idx]
34-
for j1 in 0:N1-1
26+
@inbounds for j1 in 0:N1-1
3527
wk2 = wj1
3628
g(tmp, in, N2*j1+1, start_in + j1*s_in, d, right.type, right_idx)
37-
j1 > 0 && for k2 in 1:N2-1
29+
j1 > 0 && @inbounds for k2 in 1:N2-1
3830
tmp[N2*j1 + k2 + 1] *= wk2
3931
wk2 *= wj1
4032
end
4133
wj1 *= w1
4234
end
4335

44-
for k2 in 0:N2-1
36+
@inbounds for k2 in 0:N2-1
4537
g(out, tmp, start_out + k2*s_out, k2+1, d, left.type, left_idx)
4638
end
4739
end
4840

49-
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::Pow2FFT, g::CallGraph{T}, idx::Int) where {T,U}
50-
root = g[idx]
51-
N = root.sz
52-
s_in = root.s_in
53-
s_out = root.s_out
54-
fft_pow2!(out, in, N, start_out, s_out, start_in, s_in, d)
55-
end
56-
57-
"""
58-
Power of 2 FFT in place, complex
59-
60-
"""
61-
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T, U}
62-
if N == 2
63-
out[start_out] = in[start_in] + in[start_in + stride_in]
64-
out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
65-
return
66-
end
67-
m = N ÷ 2
68-
69-
fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, d)
70-
fft_pow2!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*2, d)
71-
72-
w1 = convert(T, cispi(direction_sign(d)*2/N))
73-
wj = one(T)
74-
@inbounds for j in 0:m-1
75-
j1_out = start_out + j*stride_out
76-
j2_out = start_out + (j+m)*stride_out
77-
out_j = out[j1_out]
78-
out[j1_out] = out_j + wj*out[j2_out]
79-
out[j2_out] = out_j - wj*out[j2_out]
80-
wj *= w1
81-
end
82-
end
83-
8441
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T}
8542
tmp = in[start_in]
8643
@inbounds for j in 1:N-1
@@ -131,4 +88,93 @@ end
13188
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::DFT, g::CallGraph{T}, idx::Int) where {T,U}
    # Read the size and strides stored on call-graph node `idx` and forward them
    # to the direct DFT kernel.
    node = g[idx]
    return fft_dft!(out, in, node.sz, start_out, node.s_out, start_in, node.s_in, d)
end
92+
93+
"""
94+
Power of 2 FFT in place
95+
96+
"""
97+
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T, U}
    # Recursive radix-2 transform of the `N` entries of `in` located at
    # `start_in`, `start_in + stride_in`, ...; results are written to `out` at
    # `start_out`, `start_out + stride_out`, ....
    # NOTE(review): the recursion only terminates at N == 2, so N is presumably
    # always a power of two >= 2 — confirm against the call-graph construction.
    if N == 2
        # Length-2 base case: a single butterfly read straight from `in`.
        out[start_out] = in[start_in] + in[start_in + stride_in]
        out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
        return nothing
    end

    half = N ÷ 2
    upper_start = start_out + half*stride_out

    # Doubling the input stride selects every other entry: the entries starting
    # at `start_in` fill the lower half of the output range, the entries
    # starting at `start_in + stride_in` fill the upper half.
    fft_pow2!(out, in, half, start_out, stride_out, start_in, 2*stride_in, d)
    fft_pow2!(out, in, half, upper_start, stride_out, start_in + stride_in, 2*stride_in, d)

    # Combine the halves; the twiddle factor is advanced by repeated
    # multiplication with the primitive root `step`.
    step = convert(T, cispi(direction_sign(d)*2/N))
    twiddle = one(T)
    lo = start_out
    hi = upper_start
    @inbounds for _ in 1:half
        even = out[lo]
        odd = twiddle*out[hi]
        out[lo] = even + odd
        out[hi] = even - odd
        twiddle *= step
        lo += stride_out
        hi += stride_out
    end
    return nothing
end
119+
120+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::Pow2FFT, g::CallGraph{T}, idx::Int) where {T,U}
    # Read the size and strides stored on call-graph node `idx` and forward them
    # to the radix-2 kernel.
    node = g[idx]
    return fft_pow2!(out, in, node.sz, start_out, node.s_out, start_in, node.s_in, d)
end
127+
128+
"""
129+
Power of 4 FFT in place
130+
131+
"""
132+
function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T, U}
    # Recursive radix-4 transform of the `N` entries of `in` located at
    # `start_in`, `start_in + stride_in`, ...; results are written to `out` at
    # `start_out`, `start_out + stride_out`, ....
    # NOTE(review): the recursion only terminates at N == 4, so N is presumably
    # always a power of four >= 4 — confirm against the call-graph construction.
    ds = direction_sign(d)
    # `plusi` is the primitive 4th root of unity for this direction (ds*im);
    # `minusi` is its negation (equivalently its cube).
    plusi = ds*1im
    minusi = -plusi

    if N == 4
        # Length-4 base case. Inputs are loaded once so the result does not
        # depend on whether `out` shares memory with `in`.
        x0 = in[start_in]
        x1 = in[start_in + stride_in]
        x2 = in[start_in + 2*stride_in]
        x3 = in[start_in + 3*stride_in]
        out[start_out]                = x0 + x1 + x2 + x3
        out[start_out + stride_out]   = x0 + x1*plusi - x2 + x3*minusi
        out[start_out + 2*stride_out] = x0 - x1 + x2 - x3
        out[start_out + 3*stride_out] = x0 + x1*minusi - x2 + x3*plusi
        return nothing
    end

    m = N ÷ 4

    # Quadrupling the input stride selects every fourth entry; sub-transform `q`
    # (inputs starting at `start_in + q*stride_in`) fills output quarter `q`.
    fft_pow4!(out, in, m, start_out,                  stride_out, start_in,               stride_in*4, d)
    fft_pow4!(out, in, m, start_out + m*stride_out,   stride_out, start_in + stride_in,   stride_in*4, d)
    fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, d)
    fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, d)

    # w1, w2, w3 are the first three powers of the primitive N-th root of unity;
    # wk1, wk2, wk3 track w^k, w^(2k), w^(3k) by repeated multiplication.
    w1 = convert(T, cispi(ds*2/N))
    w2 = convert(T, cispi(ds*4/N))
    w3 = convert(T, cispi(ds*6/N))
    wk1 = wk2 = wk3 = one(T)

    @inbounds for k in 0:m-1
        k0 = start_out + k*stride_out
        k1 = start_out + (k + m)*stride_out
        k2 = start_out + (k + 2*m)*stride_out
        k3 = start_out + (k + 3*m)*stride_out

        # Each quarter is scaled by its own twiddle exactly once. The original
        # code scaled the fourth quarter by wk2 instead of wk3 when forming
        # out[k0], which corrupted the first quarter of the output.
        y0 = out[k0]
        t1 = out[k1]*wk1
        t2 = out[k2]*wk2
        t3 = out[k3]*wk3

        out[k0] = (y0 + t2) + (t1 + t3)
        @muladd out[k1] = (y0 - t2) + (t1 - t3) * plusi
        out[k2] = (y0 + t2) - (t1 + t3)
        @muladd out[k3] = (y0 - t2) + (t1 - t3) * minusi

        wk1 *= w1
        wk2 *= w2
        wk3 *= w3
    end
    return nothing
end
173+
174+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::Pow4FFT, g::CallGraph{T}, idx::Int) where {T,U}
    # Read the size and strides stored on call-graph node `idx` and forward them
    # to the radix-4 kernel.
    node = g[idx]
    return fft_pow4!(out, in, node.sz, start_out, node.s_out, start_in, node.s_in, d)
end

src/callgraph.jl

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@enum Direction FFT_FORWARD=-1 FFT_BACKWARD=1
2+
@enum Pow24 POW2=2 POW4=1
23

34
abstract type AbstractFFTType end
45

@@ -8,6 +9,9 @@ struct CompositeFFT <: AbstractFFTType end
89
# Represents a Radix-2 Cooley-Tukey FFT
910
struct Pow2FFT <: AbstractFFTType end
1011

12+
# Represents a Radix-4 Cooley-Tukey FFT
13+
struct Pow4FFT <: AbstractFFTType end
14+
1115
# Represents an O(N²) DFT
1216
struct DFT <: AbstractFFTType end
1317

@@ -62,16 +66,49 @@ leftNode(g::CallGraph, i::Int) = g[i+g[i].left]
6266
# Get the right child of the node at index `i`
6367
rightNode(g::CallGraph, i::Int) = g[i+g[i].right]
6468

69+
"""
    _ispow(N, base)

Return `true` when `N` can be reduced to exactly `1` by repeatedly dividing out
`base`, i.e. when `N` is a non-negative integer power of `base`; otherwise
return `false`.

Bases `2` and `4` use bit masks and shifts; every other base uses integer
remainder and integer division, so an integer `N` stays an integer.

# Throws
- `ArgumentError` if `abs(base) <= 1`: such a base can never reduce `N`, so the
  division loop would not terminate.
"""
function _ispow(N, base)
    abs(base) > 1 || throw(ArgumentError("base must satisfy abs(base) > 1, got $base"))
    # Zero is divisible by everything and never shrinks, so it has to be
    # rejected up front instead of entering one of the loops below.
    iszero(N) && return false
    if base == 2
        while N & 0b1 == 0
            N >>= 1
        end
    elseif base == 4
        while N & 0b11 == 0
            N >>= 2
        end
    else
        while N % base == 0
            # Integer division keeps `N` in its original type (the previous `/`
            # silently converted it to Float64).
            N ÷= base
        end
    end
    return N == 1
end
87+
88+
"""
    _ispow24(N::Int)

Strip every factor of four from `N` and classify what remains: a remainder of
`1` or `2` is returned as `Pow24(remainder)`; anything else, and any `N < 1`,
yields `nothing`.
"""
function _ispow24(N::Int)
    if N < 1
        return nothing
    end
    # Every pair of trailing zero bits is one factor of four.
    fours = trailing_zeros(N) ÷ 2
    residue = N >> (2 * fours)
    if residue < 3
        return Pow24(residue)
    end
    return nothing
end
95+
6596
# Recursively instantiate a set of `CallGraphNode`s
6697
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
67-
facs = factor(N)
68-
Ns = [first(x) for x in collect(facs) for _ in 1:last(x)]
69-
if length(Ns) == 1 || Ns[end] == 2
98+
if N % 2 == 0
99+
pow = _ispow24(N)
100+
if !isnothing(pow)
101+
push!(workspace, T[])
102+
push!(nodes, CallGraphNode(0, 0, pow == POW2 ? Pow2FFT() : Pow4FFT(), N, s_in, s_out))
103+
return 1
104+
end
105+
end
106+
if isprime(N)
70107
push!(workspace, T[])
71-
push!(nodes, CallGraphNode(0,0,Ns[end] == 2 ? Pow2FFT() : DFT(),N, s_in, s_out))
108+
push!(nodes, CallGraphNode(0,0, DFT(),N, s_in, s_out))
72109
return 1
73110
end
74-
111+
Ns = [first(x) for x in collect(factor(N)) for _ in 1:last(x)]
75112
if Ns[1] == 2
76113
N1 = prod(Ns[Ns .== 2])
77114
else

test/complex_backward.jl renamed to test/onedim/complex_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 100]
2+
test_nums = [8, 11, 15, 16, 100]
33
@testset "backward" begin
44
for N in test_nums
55
x = ones(ComplexF64, N)

test/complex_forward.jl renamed to test/onedim/complex_forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 100]
2+
test_nums = [8, 11, 15, 16, 100]
33
@testset verbose = true " forward" begin
44
for N in test_nums
55
x = ones(ComplexF64, N)

test/real_backward.jl renamed to test/onedim/real_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 100]
2+
test_nums = [8, 11, 15, 16, 100]
33
@testset "backward" begin
44
for N in test_nums
55
x = ones(Float64, N)

test/real_forward.jl renamed to test/onedim/real_forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 100]
2+
test_nums = [8, 11, 15, 16, 100]
33
@testset verbose = true " forward" begin
44
for N in test_nums
55
x = ones(Float64, N)

test/runtests.jl

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,41 @@ function padnum(m,x)
99
end
1010

1111
Random.seed!(1)
12-
13-
@testset verbose = true "1D" begin
14-
@testset verbose = true "Complex" begin
15-
include("complex_forward.jl")
16-
include("complex_backward.jl")
17-
x = rand(ComplexF64, 100)
18-
y = fft(x)
19-
x2 = bfft(y)/length(x)
20-
@test x ≈ x2 atol=1e-12
12+
@testset verbose = true "FFTA" begin
13+
@testset verbose = true "1D" begin
14+
@testset verbose = true "Complex" begin
15+
include("onedim/complex_forward.jl")
16+
include("onedim/complex_backward.jl")
17+
x = rand(ComplexF64, 100)
18+
y = fft(x)
19+
x2 = bfft(y)/length(x)
20+
@test x ≈ x2 atol=1e-12
21+
end
22+
@testset verbose = true "Real" begin
23+
include("onedim/real_forward.jl")
24+
include("onedim/real_backward.jl")
25+
x = rand(Float64, 100)
26+
y = fft(x)
27+
x2 = bfft(y)/length(x)
28+
@test x ≈ x2 atol=1e-12
29+
end
2130
end
22-
@testset verbose = true "Real" begin
23-
include("real_forward.jl")
24-
include("real_backward.jl")
25-
x = rand(Float64, 100)
26-
y = fft(x)
27-
x2 = bfft(y)/length(x)
28-
@test x ≈ x2 atol=1e-12
31+
@testset verbose = true "2D" begin
32+
@testset verbose = true "Complex" begin
33+
include("twodim/complex_forward.jl")
34+
include("twodim/complex_backward.jl")
35+
x = rand(ComplexF64, 100, 100)
36+
y = fft(x)
37+
x2 = bfft(y)/length(x)
38+
@test x ≈ x2
39+
end
40+
@testset verbose = true "Real" begin
41+
include("twodim/real_forward.jl")
42+
include("twodim/real_backward.jl")
43+
x = rand(Float64, 100, 100)
44+
y = fft(x)
45+
x2 = bfft(y)/length(x)
46+
@test x ≈ x2
47+
end
2948
end
3049
end

0 commit comments

Comments
 (0)