Skip to content

Commit 24cfc87

Browse files
committed
Add power of three implementation
1 parent 5f83277 commit 24cfc87

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

src/algos.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,51 @@ function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, sta
178178
s_out = root.s_out
179179
fft_pow4!(out, in, N, start_out, s_out, start_in, s_in, d)
180180
end
181+
182+
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction, plus120::T, minus120::T) where {T, U}
183+
if N == 3
184+
@muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in]
185+
@muladd out[start_out + stride_out] = in[start_in] + in[start_in + stride_in]*plus120 + in[start_in + 2*stride_in]*minus120
186+
@muladd out[start_out + 2*stride_out] = in[start_in] + in[start_in + stride_in]*minus120 + in[start_in + 2*stride_in]*plus120
187+
return
188+
end
189+
190+
# Size of subproblem
191+
Nprime = N/3
192+
193+
ds = direction_sign(d)
194+
195+
# Dividing into subproblems
196+
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, d, plus120, minus120)
197+
fft_pow3!(out, in, Nprime, start_out + N_prime*stride_out, stride_out, start_in + stride_in, stride_in*3, d, plus120, minus120)
198+
fft_pow3!(out, in, Nprime, start_out + 2*N_prime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, d, plus120, minus120)
199+
200+
w1 = convert(T, cispi(ds*2/N))
201+
w2 = convert(T, cispi(ds*4/N))
202+
wk1 = wk2 = one(T)
203+
for k in 0:Nprime-1
204+
@muladd k0 = start_out + k*stride_out
205+
@muladd k1 = start_out + (k+Nprime)*stride_out
206+
@muladd k2 = start_out + (k+2*Nprime)*stride_out
207+
y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2]
208+
@muladd out[k0] = y_k0 + y_k1*wk1 + y_k2*wk2
209+
@muladd out[k1] = y_k0 + y_k1*wk1*plus120 + y_k2*wk2*minus120
210+
@muladd out[k2] = y_k0 + y_k1*wk1*minus120 + y_k2*wk2*plus120
211+
wk1 *= w1
212+
wk2 *= w2
213+
end
214+
end
215+
216+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::Pow3FFT, g::CallGraph{T}, idx::Int) where {T,U}
217+
root = g[idx]
218+
N = root.sz
219+
s_in = root.s_in
220+
s_out = root.s_out
221+
p_120 = convert(T, cispi(2/3))
222+
m_120 = convert(T, cispi(4/3))
223+
if d == FFT_FORWARD
224+
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, d, m_120, p_120)
225+
else
226+
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, d, p_120, m_120)
227+
end
228+
end

src/callgraph.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ struct CompositeFFT <: AbstractFFTType end
99
# Represents a Radix-2 Cooley-Tukey FFT
1010
struct Pow2FFT <: AbstractFFTType end
1111

12+
# Represents a Radix-3 Cooley-Tukey FFT
13+
struct Pow3FFT <: AbstractFFTType end
14+
1215
# Represents a Radix-4 Cooley-Tukey FFT
1316
struct Pow4FFT <: AbstractFFTType end
1417

@@ -95,7 +98,7 @@ end
9598

9699
# Recursively instantiate a set of `CallGraphNode`s
97100
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
98-
if N % 2 == 0
101+
if iseven(N)
99102
pow = _ispow24(N)
100103
if !isnothing(pow)
101104
push!(workspace, T[])

0 commit comments

Comments
 (0)