Skip to content

Commit 3b47b59

Browse files
committed
Fix the size one case and provide informative error message in the
size zero case. Currently, the size one case errors and the size zero case hangs. I've also added more tests. They revealed a bug in the pow4 fft which I have also fixed.
1 parent 171bee0 commit 3b47b59

File tree

9 files changed

+120
-26
lines changed

9 files changed

+120
-26
lines changed

src/algos.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_o
7373
tmp += in[start_in + j*stride_in]
7474
end
7575
out[start_out] = tmp
76-
76+
7777
wk = wkn = w = convert(T, cispi(direction_sign(d)*2/N))
7878
@inbounds for d in 1:N-1
7979
tmp = in[start_in]
@@ -98,7 +98,7 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int
9898
end
9999
out[start_out] = convert(Complex{T}, tmpBegin)
100100
iseven(N) && (out[start_out + stride_out*halfN] = convert(Complex{T}, tmpHalf))
101-
101+
102102
@inbounds for d in 1:halfN
103103
tmp = in[start_in]
104104
@inbounds for k in 1:N-1
@@ -200,7 +200,7 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
200200

201201
w1 = convert(T, cispi(direction_sign(d)*2/N))
202202
wj = one(T)
203-
203+
204204
w1 = convert(T, cispi(ds*2/N))
205205
w2 = convert(T, cispi(ds*4/N))
206206
w3 = convert(T, cispi(ds*6/N))
@@ -212,7 +212,7 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
212212
@muladd k2 = start_out + (k+2*m)*stride_out
213213
@muladd k3 = start_out + (k+3*m)*stride_out
214214
y_k0, y_k1, y_k2, y_k3 = out[k0], out[k1], out[k2], out[k3]
215-
@muladd out[k0] = (y_k0 + y_k2*wk2) + (y_k1*wk1 + y_k3*wk2)
215+
@muladd out[k0] = (y_k0 + y_k2*wk2) + (y_k1*wk1 + y_k3*wk3)
216216
@muladd out[k1] = (y_k0 - y_k2*wk2) + (y_k1*wk1 - y_k3*wk3) * plusi
217217
@muladd out[k2] = (y_k0 + y_k2*wk2) - (y_k1*wk1 + y_k3*wk3)
218218
@muladd out[k3] = (y_k0 - y_k2*wk2) + (y_k1*wk1 - y_k3*wk3) * minusi

src/callgraph.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,6 @@ end
5555
# Get the node in the graph at index i
5656
Base.getindex(g::CallGraph{T}, i::Int) where {T} = g.nodes[i]
5757

58-
"""
59-
$(TYPEDSIGNATURES)
60-
Check if `N` is a power of `base`
61-
62-
"""
63-
function _ispow(N, base)
64-
while N % base == 0
65-
N = N/base
66-
end
67-
return N == 1
68-
end
69-
7058
"""
7159
$(TYPEDSIGNATURES)
7260
Check if `N` is a power of 2 or 4
@@ -93,6 +81,9 @@ Recursively instantiate a set of `CallGraphNode`s
9381
9482
"""
9583
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
84+
if N == 0
85+
throw(DimensionMismatch("array has to be non-empty"))
86+
end
9687
if iseven(N)
9788
pow = _ispow24(N)
9889
if !isnothing(pow)
@@ -102,15 +93,15 @@ function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{
10293
end
10394
end
10495
if N % 3 == 0
105-
if _ispow(N, 3)
96+
if nextpow(3, N) == N
10697
push!(workspace, T[])
10798
push!(nodes, CallGraphNode(0, 0, Pow3FFT(), N, s_in, s_out))
10899
return 1
109100
end
110101
end
111-
if isprime(N)
102+
if N == 1 || isprime(N)
112103
push!(workspace, T[])
113-
push!(nodes, CallGraphNode(0,0, DFT(),N, s_in, s_out))
104+
push!(nodes, CallGraphNode(0, 0, DFT(), N, s_in, s_out))
114105
return 1
115106
end
116107
Ns = [first(x) for x in collect(factor(N)) for _ in 1:last(x)]

test/onedim/complex_backward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,13 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = N
99
@test y y_ref atol=1e-12
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
14+
x = complex.(randn(n), randn(n))
15+
@test naive_1d_fourier_transform(x, FFTA.FFT_BACKWARD) bfft(x)
16+
end
17+
18+
@testset "error messages" begin
19+
@test_throws DimensionMismatch bfft(zeros(0))
20+
end

test/onedim/complex_forward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,13 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = N
99
@test y y_ref atol=1e-12
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
14+
x = complex.(randn(n), randn(n))
15+
@test naive_1d_fourier_transform(x, FFTA.FFT_FORWARD) fft(x)
16+
end
17+
18+
@testset "error messages" begin
19+
@test_throws DimensionMismatch fft(zeros(0))
20+
end

test/onedim/real_backward.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,20 @@ test_nums = [8, 11, 15, 16, 27, 100]
1111
end
1212
@test y_ref y atol=1e-12
1313
end
14-
end
14+
end
15+
16+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
17+
x = complex.(randn(n ÷ 2 + 1), randn(n ÷ 2 + 1))
18+
x[begin] = real(x[begin])
19+
if iseven(n)
20+
x[end] = real(x[end])
21+
xe = [x; conj.(reverse(x[begin + 1:end - 1]))]
22+
else
23+
xe = [x; conj.(reverse(x[begin + 1:end]))]
24+
end
25+
@test naive_1d_fourier_transform(xe, FFTA.FFT_BACKWARD) brfft(x, n)
26+
end
27+
28+
@testset "error messages" begin
29+
@test_throws DimensionMismatch brfft(zeros(ComplexF64, 0), 0)
30+
end

test/onedim/real_forward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
using FFTA, Test
22
test_nums = [8, 11, 15, 16, 27, 100]
3-
@testset verbose = true " forward" begin
3+
@testset verbose = true " forward" begin
44
for N in test_nums
55
x = ones(Float64, N)
66
y = rfft(x)
77
y_ref = 0*y
88
y_ref[1] = N
99
@test y y_ref atol=1e-12
1010
end
11+
end
12+
13+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
14+
x = randn(n)
15+
@test naive_1d_fourier_transform(x, FFTA.FFT_FORWARD)[1:(n ÷ 2 + 1)] rfft(x)
16+
end
17+
18+
@testset "error messages" begin
19+
@test_throws DimensionMismatch rfft(zeros(0))
1120
end

test/runtests.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, Random
1+
using Test, Random, FFTA
22

33
function padnum(m,x)
44
digs = floor(Int, log10(m))
@@ -8,6 +8,42 @@ function padnum(m,x)
88
String(v)
99
end
1010

11+
function naive_1d_fourier_transform(x::Vector, d::FFTA.Direction)
12+
n = length(x)
13+
y = zeros(Complex{Float64}, n)
14+
15+
for u in 0:(n - 1)
16+
s = 0.0 + 0.0im
17+
for v in 0:(n - 1)
18+
a = FFTA.direction_sign(d) * 2π * u * v / n
19+
s += x[v + 1] * exp(im * a)
20+
end
21+
y[u + 1] = s
22+
end
23+
24+
return y
25+
end
26+
27+
function naive_2d_fourier_transform(X::Matrix, d::FFTA.Direction)
28+
rows, cols = size(X)
29+
Y = zeros(Complex{Float64}, rows, cols)
30+
31+
for u in 0:(rows - 1)
32+
for v in 0:(cols - 1)
33+
s = 0.0 + 0.0im
34+
for x in 0:(rows - 1)
35+
for y in 0:(cols - 1)
36+
a = FFTA.direction_sign(d) * 2π * (u * x / rows + v * y / cols)
37+
s += X[x + 1, y + 1] * exp(im * a)
38+
end
39+
end
40+
Y[u + 1, v + 1] = s
41+
end
42+
end
43+
44+
return Y
45+
end
46+
1147
Random.seed!(1)
1248
@testset verbose = true "FFTA" begin
1349
@testset verbose = true "1D" begin

test/twodim/complex_backward.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,16 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = length(x)
99
@test y y_ref
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation" for n in 1:64
14+
@testset "size: ($m, $n)" for m in n:(n + 1)
15+
X = complex.(randn(m, n), randn(m, n))
16+
Y = similar(X)
17+
@test naive_2d_fourier_transform(X, FFTA.FFT_BACKWARD) bfft(X)
18+
end
19+
end
20+
21+
@testset "error messages" begin
22+
@test_throws DimensionMismatch bfft(zeros(ComplexF64, 0, 0))
23+
end

test/twodim/complex_forward.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,16 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = length(x)
99
@test y y_ref
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation" for n in 1:64
14+
@testset "size: ($m, $n)" for m in n:(n + 1)
15+
X = complex.(randn(m, n), randn(m, n))
16+
Y = similar(X)
17+
@test naive_2d_fourier_transform(X, FFTA.FFT_FORWARD) fft(X)
18+
end
19+
end
20+
21+
@testset "error messages" begin
22+
@test_throws DimensionMismatch fft(zeros(0, 0))
23+
end

0 commit comments

Comments
 (0)