Skip to content

Commit d577a6b

Browse files
authored
Merge pull request #47 from andreasnoack/an/fix01
Fix the size one case and provide informative error message in the size zero case.
2 parents 171bee0 + 3b47b59 commit d577a6b

File tree

9 files changed

+120
-26
lines changed

9 files changed

+120
-26
lines changed

src/algos.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_o
7373
tmp += in[start_in + j*stride_in]
7474
end
7575
out[start_out] = tmp
76-
76+
7777
wk = wkn = w = convert(T, cispi(direction_sign(d)*2/N))
7878
@inbounds for d in 1:N-1
7979
tmp = in[start_in]
@@ -98,7 +98,7 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int
9898
end
9999
out[start_out] = convert(Complex{T}, tmpBegin)
100100
iseven(N) && (out[start_out + stride_out*halfN] = convert(Complex{T}, tmpHalf))
101-
101+
102102
@inbounds for d in 1:halfN
103103
tmp = in[start_in]
104104
@inbounds for k in 1:N-1
@@ -200,7 +200,7 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
200200

201201
w1 = convert(T, cispi(direction_sign(d)*2/N))
202202
wj = one(T)
203-
203+
204204
w1 = convert(T, cispi(ds*2/N))
205205
w2 = convert(T, cispi(ds*4/N))
206206
w3 = convert(T, cispi(ds*6/N))
@@ -212,7 +212,7 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
212212
@muladd k2 = start_out + (k+2*m)*stride_out
213213
@muladd k3 = start_out + (k+3*m)*stride_out
214214
y_k0, y_k1, y_k2, y_k3 = out[k0], out[k1], out[k2], out[k3]
215-
@muladd out[k0] = (y_k0 + y_k2*wk2) + (y_k1*wk1 + y_k3*wk2)
215+
@muladd out[k0] = (y_k0 + y_k2*wk2) + (y_k1*wk1 + y_k3*wk3)
216216
@muladd out[k1] = (y_k0 - y_k2*wk2) + (y_k1*wk1 - y_k3*wk3) * plusi
217217
@muladd out[k2] = (y_k0 + y_k2*wk2) - (y_k1*wk1 + y_k3*wk3)
218218
@muladd out[k3] = (y_k0 - y_k2*wk2) + (y_k1*wk1 - y_k3*wk3) * minusi

src/callgraph.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,6 @@ end
5555
# Get the node in the graph at index i
5656
Base.getindex(g::CallGraph{T}, i::Int) where {T} = g.nodes[i]
5757

58-
"""
59-
$(TYPEDSIGNATURES)
60-
Check if `N` is a power of `base`
61-
62-
"""
63-
function _ispow(N, base)
64-
while N % base == 0
65-
N = N/base
66-
end
67-
return N == 1
68-
end
69-
7058
"""
7159
$(TYPEDSIGNATURES)
7260
Check if `N` is a power of 2 or 4
@@ -93,6 +81,9 @@ Recursively instantiate a set of `CallGraphNode`s
9381
9482
"""
9583
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
84+
if N == 0
85+
throw(DimensionMismatch("array has to be non-empty"))
86+
end
9687
if iseven(N)
9788
pow = _ispow24(N)
9889
if !isnothing(pow)
@@ -102,15 +93,15 @@ function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{
10293
end
10394
end
10495
if N % 3 == 0
105-
if _ispow(N, 3)
96+
if nextpow(3, N) == N
10697
push!(workspace, T[])
10798
push!(nodes, CallGraphNode(0, 0, Pow3FFT(), N, s_in, s_out))
10899
return 1
109100
end
110101
end
111-
if isprime(N)
102+
if N == 1 || isprime(N)
112103
push!(workspace, T[])
113-
push!(nodes, CallGraphNode(0,0, DFT(),N, s_in, s_out))
104+
push!(nodes, CallGraphNode(0, 0, DFT(), N, s_in, s_out))
114105
return 1
115106
end
116107
Ns = [first(x) for x in collect(factor(N)) for _ in 1:last(x)]

test/onedim/complex_backward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,13 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = N
99
@test y y_ref atol=1e-12
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
14+
x = complex.(randn(n), randn(n))
15+
@test naive_1d_fourier_transform(x, FFTA.FFT_BACKWARD) bfft(x)
16+
end
17+
18+
@testset "error messages" begin
19+
@test_throws DimensionMismatch bfft(zeros(0))
20+
end

test/onedim/complex_forward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,13 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = N
99
@test y y_ref atol=1e-12
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
14+
x = complex.(randn(n), randn(n))
15+
@test naive_1d_fourier_transform(x, FFTA.FFT_FORWARD) fft(x)
16+
end
17+
18+
@testset "error messages" begin
19+
@test_throws DimensionMismatch fft(zeros(0))
20+
end

test/onedim/real_backward.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,20 @@ test_nums = [8, 11, 15, 16, 27, 100]
1111
end
1212
@test y_ref y atol=1e-12
1313
end
14-
end
14+
end
15+
16+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
17+
x = complex.(randn(n ÷ 2 + 1), randn(n ÷ 2 + 1))
18+
x[begin] = real(x[begin])
19+
if iseven(n)
20+
x[end] = real(x[end])
21+
xe = [x; conj.(reverse(x[begin + 1:end - 1]))]
22+
else
23+
xe = [x; conj.(reverse(x[begin + 1:end]))]
24+
end
25+
@test naive_1d_fourier_transform(xe, FFTA.FFT_BACKWARD) brfft(x, n)
26+
end
27+
28+
@testset "error messages" begin
29+
@test_throws DimensionMismatch brfft(zeros(ComplexF64, 0), 0)
30+
end

test/onedim/real_forward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
using FFTA, Test
22
test_nums = [8, 11, 15, 16, 27, 100]
3-
@testset verbose = true " forward" begin
3+
@testset verbose = true " forward" begin
44
for N in test_nums
55
x = ones(Float64, N)
66
y = rfft(x)
77
y_ref = 0*y
88
y_ref[1] = N
99
@test y y_ref atol=1e-12
1010
end
11+
end
12+
13+
@testset verbose = true "against naive implementation. Size: $n" for n in 1:64
14+
x = randn(n)
15+
@test naive_1d_fourier_transform(x, FFTA.FFT_FORWARD)[1:(n ÷ 2 + 1)] rfft(x)
16+
end
17+
18+
@testset "error messages" begin
19+
@test_throws DimensionMismatch rfft(zeros(0))
1120
end

test/runtests.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, Random
1+
using Test, Random, FFTA
22

33
function padnum(m,x)
44
digs = floor(Int, log10(m))
@@ -8,6 +8,42 @@ function padnum(m,x)
88
String(v)
99
end
1010

11+
function naive_1d_fourier_transform(x::Vector, d::FFTA.Direction)
12+
n = length(x)
13+
y = zeros(Complex{Float64}, n)
14+
15+
for u in 0:(n - 1)
16+
s = 0.0 + 0.0im
17+
for v in 0:(n - 1)
18+
a = FFTA.direction_sign(d) * 2π * u * v / n
19+
s += x[v + 1] * exp(im * a)
20+
end
21+
y[u + 1] = s
22+
end
23+
24+
return y
25+
end
26+
27+
function naive_2d_fourier_transform(X::Matrix, d::FFTA.Direction)
28+
rows, cols = size(X)
29+
Y = zeros(Complex{Float64}, rows, cols)
30+
31+
for u in 0:(rows - 1)
32+
for v in 0:(cols - 1)
33+
s = 0.0 + 0.0im
34+
for x in 0:(rows - 1)
35+
for y in 0:(cols - 1)
36+
a = FFTA.direction_sign(d) * 2π * (u * x / rows + v * y / cols)
37+
s += X[x + 1, y + 1] * exp(im * a)
38+
end
39+
end
40+
Y[u + 1, v + 1] = s
41+
end
42+
end
43+
44+
return Y
45+
end
46+
1147
Random.seed!(1)
1248
@testset verbose = true "FFTA" begin
1349
@testset verbose = true "1D" begin

test/twodim/complex_backward.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,16 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = length(x)
99
@test y y_ref
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation" for n in 1:64
14+
@testset "size: ($m, $n)" for m in n:(n + 1)
15+
X = complex.(randn(m, n), randn(m, n))
16+
Y = similar(X)
17+
@test naive_2d_fourier_transform(X, FFTA.FFT_BACKWARD) bfft(X)
18+
end
19+
end
20+
21+
@testset "error messages" begin
22+
@test_throws DimensionMismatch bfft(zeros(ComplexF64, 0, 0))
23+
end

test/twodim/complex_forward.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,16 @@ test_nums = [8, 11, 15, 16, 27, 100]
88
y_ref[1] = length(x)
99
@test y y_ref
1010
end
11-
end
11+
end
12+
13+
@testset verbose = true "against naive implementation" for n in 1:64
14+
@testset "size: ($m, $n)" for m in n:(n + 1)
15+
X = complex.(randn(m, n), randn(m, n))
16+
Y = similar(X)
17+
@test naive_2d_fourier_transform(X, FFTA.FFT_FORWARD) fft(X)
18+
end
19+
end
20+
21+
@testset "error messages" begin
22+
@test_throws DimensionMismatch fft(zeros(0, 0))
23+
end

0 commit comments

Comments
 (0)