Skip to content

Commit 0aa65af

Browse files
authored
Merge pull request #29 from dannys4/ds/pow3
Creating radix 3 FFT
2 parents 5f83277 + 59b6ea8 commit 0aa65af

11 files changed

+70
-10
lines changed

src/algos.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,51 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
178178
s_out = root.s_out
179179
fft_pow4!(out, in, N, start_out, s_out, start_in, s_in, d)
180180
end
181+
182+
"""
    fft_pow3!(out, in, N, start_out, stride_out, start_in, stride_in, d, plus120, minus120)

Radix-3 recursive transform of `N` strided elements of `in`, written into `out`.

Reads `in[start_in + j*stride_in]` and writes `out[start_out + k*stride_out]` for
`j, k` in `0:N-1`. `N` must be a positive power of 3.

# Arguments
- `out`, `in`: destination and source vectors.
- `N`: length of the (sub)problem.
- `start_out`, `stride_out`: first index and index step used when writing `out`.
- `start_in`, `stride_in`: first index and index step used when reading `in`.
- `d`: transform direction; used only through `direction_sign(d)` to choose the sign
  of the twiddle-factor exponent.
- `plus120`, `minus120`: the two non-trivial factors of the 3-point butterfly.
  `plus120` multiplies the second input in the second output and the third input in
  the third output; `minus120` takes the remaining two positions. The caller is
  responsible for passing values consistent with `d`.

# Throws
- `ArgumentError` if `N` (or a subproblem size reached during recursion) is not a
  positive power of 3.
"""
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
                   start_out::Int, stride_out::Int, start_in::Int, stride_in::Int,
                   d::Direction, plus120::T, minus120::T) where {T, U}
    # Base case: a single 3-point butterfly straight from `in` to `out`.
    if N == 3
        @muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in]
        @muladd out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]*plus120 + in[start_in + 2*stride_in]*minus120
        @muladd out[start_out + 2*stride_out] = in[start_in] + in[start_in + stride_in]*minus120 + in[start_in + 2*stride_in]*plus120
        return
    end

    # A 1-point transform is the identity. Without this case N == 1 would recurse
    # on size 0 forever.
    if N == 1
        out[start_out] = in[start_in]
        return
    end

    # Any other size must split evenly into three parts; otherwise N ÷ 3 eventually
    # reaches 0 and the recursion never terminates.
    if N < 1 || N % 3 != 0
        throw(ArgumentError("fft_pow3! requires a positive power of 3, got size $N"))
    end

    # Size of subproblem
    Nprime = N ÷ 3

    ds = direction_sign(d)

    # Dividing into subproblems: the j-th subproblem reads every third input
    # starting at offset j and writes the j-th contiguous third of the output.
    fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, d, plus120, minus120)
    fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, d, plus120, minus120)
    fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, d, plus120, minus120)

    # Twiddle factors: wk1 and wk2 hold w1^k and w2^k, built up by repeated
    # multiplication as k advances.
    w1 = convert(T, cispi(ds*2/N))
    w2 = convert(T, cispi(ds*4/N))
    wk1 = wk2 = one(T)
    for k in 0:Nprime-1
        @muladd k0 = start_out + k*stride_out
        @muladd k1 = start_out + (k+Nprime)*stride_out
        @muladd k2 = start_out + (k+2*Nprime)*stride_out
        # Load all three values before storing: the combine step is in place.
        y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2]
        @muladd out[k0] = y_k0 + y_k1*wk1 + y_k2*wk2
        @muladd out[k1] = y_k0 + y_k1*wk1*plus120 + y_k2*wk2*minus120
        @muladd out[k2] = y_k0 + y_k1*wk1*minus120 + y_k2*wk2*plus120
        wk1 *= w1
        wk2 *= w2
    end
    return
end
215+
216+
# `fft!` method for call-graph nodes tagged `Pow3FFT`: reads the node's size and
# strides from `g[idx]` and runs `fft_pow3!` with the two butterfly factors ordered
# to match the direction `d`.
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::Pow3FFT, g::CallGraph{T}, idx::Int) where {T,U}
    node = g[idx]

    # The two non-trivial complex cube roots of unity.
    cis_2_3 = convert(T, cispi(2/3))
    cis_4_3 = convert(T, cispi(4/3))

    # The forward transform passes the roots in the opposite order from the
    # backward one.
    first_tw, second_tw = d == FFT_FORWARD ? (cis_4_3, cis_2_3) : (cis_2_3, cis_4_3)

    return fft_pow3!(out, in, node.sz, start_out, node.s_out, start_in, node.s_in, d, first_tw, second_tw)
end

src/callgraph.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ struct CompositeFFT <: AbstractFFTType end
99
# Represents a Radix-2 Cooley-Tukey FFT
1010
# Field-less tag type; `CallGraphNode!` stores `Pow2FFT()` in a node when
# `_ispow24(N)` returns `POW2`.
struct Pow2FFT <: AbstractFFTType end
1111

12+
# Represents a Radix-3 Cooley-Tukey FFT
13+
# Field-less tag type used for dispatch: `fft!` has a method taking `::Pow3FFT`, and
# `CallGraphNode!` stores `Pow3FFT()` in a node when `_ispow(N, 3)` holds.
struct Pow3FFT <: AbstractFFTType end
14+
1215
# Represents a Radix-4 Cooley-Tukey FFT
1316
# Field-less tag type; `CallGraphNode!` stores `Pow4FFT()` in a node when
# `_ispow24(N)` returns a non-`nothing` value other than `POW2`.
struct Pow4FFT <: AbstractFFTType end
1417

@@ -95,14 +98,21 @@ end
9598

9699
# Recursively instantiate a set of `CallGraphNode`s
97100
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
98-
if N % 2 == 0
101+
if iseven(N)
99102
pow = _ispow24(N)
100103
if !isnothing(pow)
101104
push!(workspace, T[])
102105
push!(nodes, CallGraphNode(0, 0, pow == POW2 ? Pow2FFT() : Pow4FFT(), N, s_in, s_out))
103106
return 1
104107
end
105108
end
109+
if N % 3 == 0
110+
if _ispow(N, 3)
111+
push!(workspace, T[])
112+
push!(nodes, CallGraphNode(0, 0, Pow3FFT(), N, s_in, s_out))
113+
return 1
114+
end
115+
end
106116
if isprime(N)
107117
push!(workspace, T[])
108118
push!(nodes, CallGraphNode(0,0, DFT(),N, s_in, s_out))
@@ -111,6 +121,8 @@ function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{
111121
Ns = [first(x) for x in collect(factor(N)) for _ in 1:last(x)]
112122
if Ns[1] == 2
113123
N1 = prod(Ns[Ns .== 2])
124+
elseif Ns[1] == 3
125+
N1 = prod(Ns[Ns .== 3])
114126
else
115127
# Greedy search for closest factor of N to sqrt(N)
116128
Nsqrt = sqrt(N)

test/onedim/complex_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset "backward" begin
44
for N in test_nums
55
x = ones(ComplexF64, N)

test/onedim/complex_forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset verbose = true " forward" begin
44
for N in test_nums
55
x = ones(ComplexF64, N)

test/onedim/real_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset "backward" begin
44
for N in test_nums
55
x = ones(Float64, N)

test/onedim/real_forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset verbose = true " forward" begin
44
for N in test_nums
55
x = ones(Float64, N)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Random.seed!(1)
2828
@test x ≈ x2 atol=1e-12
2929
end
3030
end
31-
@testset verbose = true "2D" begin
31+
@testset verbose = false "2D" begin
3232
@testset verbose = true "Complex" begin
3333
include("twodim/complex_forward.jl")
3434
include("twodim/complex_backward.jl")

test/twodim/complex_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset "backward" begin
44
for N in test_nums
55
x = ones(ComplexF64, N, N)

test/twodim/complex_forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset " forward" begin
44
for N in test_nums
55
x = ones(ComplexF64, N, N)

test/twodim/real_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using FFTA, Test
2-
test_nums = [8, 11, 15, 16, 100]
2+
test_nums = [8, 11, 15, 16, 27, 100]
33
@testset "backward" begin
44
for N in test_nums
55
x = ones(Float64, N, N)

0 commit comments

Comments
 (0)