Skip to content

Commit ed96cf5

Browse files
committed
Working on striding
1 parent dfc8c7d commit ed96cf5

File tree

4 files changed

+94
-71
lines changed

4 files changed

+94
-71
lines changed

src/FFTA.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ include("algos.jl")
1010
function fft(x::AbstractVector{T}) where {T}
1111
y = similar(x)
1212
g = CallGraph{T}(length(x))
13-
fft!(y, x, FFT_FORWARD(), g[1].type, g, 1)
13+
fft!(y, x, 1, 1, FFT_FORWARD(), g[1].type, g, 1)
1414
y
1515
end
1616

1717
function fft(x::AbstractVector{T}) where {T <: Real}
1818
y = similar(x, Complex{T})
1919
g = CallGraph{Complex{T}}(length(x))
20-
fft!(y, x, FFT_FORWARD(), g[1].type, g, 1)
20+
fft!(y, x, 1, 1, FFT_FORWARD(), g[1].type, g, 1)
2121
y
2222
end
2323

@@ -29,19 +29,19 @@ function fft(x::AbstractMatrix{T}) where {T}
2929
g2 = CallGraph{T}(size(x,2))
3030

3131
for k in 1:N
32-
@views fft!(y1[:,k], x[:,k], FFT_FORWARD(), g1[1].type, g1, 1)
32+
@views fft!(y1[:,k], x[:,k], 1, 1, FFT_FORWARD(), g1[1].type, g1, 1)
3333
end
3434

3535
for k in 1:M
36-
@views fft!(y2[k,:], y1[k,:], FFT_FORWARD(), g2[1].type, g2, 1)
36+
@views fft!(y2[k,:], y1[k,:], 1, 1, FFT_FORWARD(), g2[1].type, g2, 1)
3737
end
3838
y2
3939
end
4040

4141
function bfft(x::AbstractVector{T}) where {T}
4242
y = similar(x)
4343
g = CallGraph{T}(length(x))
44-
fft!(y, x, FFT_BACKWARD(), g[1].type, g, 1)
44+
fft!(y, x, 1, 1, FFT_BACKWARD(), g[1].type, g, 1)
4545
y
4646
end
4747

@@ -53,11 +53,11 @@ function bfft(x::AbstractMatrix{T}) where {T}
5353
g2 = CallGraph{T}(size(x,2))
5454

5555
for k in 1:N
56-
@views fft!(y1[:,k], x[:,k], FFT_BACKWARD(), g1[1].type, g1, 1)
56+
@views fft!(y1[:,k], x[:,k], 1, 1, FFT_BACKWARD(), g1[1].type, g1, 1)
5757
end
5858

5959
for k in 1:M
60-
@views fft!(y2[k,:], y1[k,:], FFT_BACKWARD(), g2[1].type, g2, 1)
60+
@views fft!(y2[k,:], y1[k,:], 1, 1, FFT_BACKWARD(), g2[1].type, g2, 1)
6161
end
6262
y2
6363
end

src/algos.jl

Lines changed: 76 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function alternatingSum(x::AbstractVector{T}) where T
66
y
77
end
88

9-
fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Direction, ::AbstractFFTType, ::CallGraph{T}, ::Int) where {T} = nothing
9+
fft!(::AbstractVector{T}, ::AbstractVector{T}, ::Int, ::Int, ::Direction, ::AbstractFFTType, ::CallGraph{T}, ::Int) where {T} = nothing
1010

1111
@inline function direction_sign(::FFT_BACKWARD)
1212
1
@@ -16,60 +16,73 @@ end
1616
-1
1717
end
1818

19-
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{U}, v::Direction, t::AbstractFFTType, idx::Int) where {T,U}
20-
fft!(out, in, v, t, g, idx)
19+
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, v::Direction, t::AbstractFFTType, idx::Int) where {T,U}
20+
fft!(out, in, start_out, start_in, v, t, g, idx)
2121
end
2222

23-
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, d::Direction, ::CompositeFFT, g::CallGraph{T}, idx::Int) where {T,U}
24-
N = length(out)
25-
left = leftNode(g,idx)
26-
right = rightNode(g,idx)
23+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::CompositeFFT, g::CallGraph{T}, idx::Int) where {T,U}
24+
root = g[idx]
25+
left_idx = idx + root.left
26+
right_idx = idx + root.right
27+
left = g[left_idx]
28+
right = g[right_idx]
29+
N = root.sz
2730
N1 = left.sz
2831
N2 = right.sz
32+
s_in = root.s_in
33+
s_out = root.s_out
34+
@info "" N N1 N2 s_in s_out start_in start_out
2935

3036
w1 = convert(T, cispi(direction_sign(d)*2/N))
3137
wj1 = one(T)
3238
tmp = g.workspace[idx]
33-
@inbounds for j1 in 1:N1
39+
@inbounds for j1 in 0:N1-1
3440
wk2 = wj1;
35-
@views g(tmp[(N2*(j1-1) + 1):(N2*j1)], in[j1:N1:end], d, right.type, idx + g[idx].right)
36-
j1 > 1 && @inbounds for k2 in 2:N2
37-
tmp[N2*(j1-1) + k2] *= wk2
41+
g(tmp, in, N2*j1+1, start_in + j1*s_in, d, right.type, right_idx)
42+
j1 > 0 && @inbounds for k2 in 2:N2
43+
tmp[N2*j1 + k2] *= wk2
3844
wk2 *= wj1
3945
end
4046
wj1 *= w1
4147
end
4248

4349
@inbounds for k2 in 1:N2
44-
@views g(out[k2:N2:end], tmp[k2:N2:end], d, left.type, idx + g[idx].left)
50+
g(out, tmp, start_out + (k2-1)*s_out, k2, d, left.type, left_idx)
4551
end
4652
end
4753

48-
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, d::Direction, a::Pow2FFT, b::CallGraph{T}, c::Int) where {T,U}
49-
fft_pow2!(out, in, d)
54+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::Pow2FFT, g::CallGraph{T}, idx::Int) where {T,U}
55+
root = g[idx]
56+
N = root.sz
57+
s_in = root.s_in
58+
s_out = root.s_out
59+
fft_pow2!(out, in, N, start_out, s_out, start_in, s_in, d)
5060
end
5161

5262
"""
5363
Power of 2 FFT in place, complex
5464
5565
"""
56-
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, d::Direction) where {T}
57-
N = length(out)
66+
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T}
67+
5868
if N == 2
59-
out[1] = in[1] + in[2]
60-
out[2] = in[1] - in[2]
69+
out[start_out] = in[start_in] + in[start_in + stride_in]
70+
out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
6171
return
6272
end
63-
fft_pow2!(@view(out[1:(end÷2)]), @view(in[1:2:end]), d)
64-
fft_pow2!(@view(out[(end÷2+1):end]), @view(in[2:2:end]), d)
73+
m = N ÷ 2
74+
75+
fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, d)
76+
fft_pow2!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*2, d)
6577

6678
w1 = convert(T, cispi(direction_sign(d)*2/N))
6779
wj = one(T)
68-
m = N ÷ 2
69-
@inbounds for j in 1:m
70-
out_j = out[j]
71-
out[j] = out_j + wj*out[j+m]
72-
out[j+m] = out_j - wj*out[j+m]
80+
@inbounds for j in 0:m-1
81+
j1_out = start_out + j*stride_out
82+
j2_out = start_out + (j+m)*stride_out
83+
out_j = out[j1_out]
84+
out[j1_out] = out_j + wj*out[j2_out]
85+
out[j2_out] = out_j - wj*out[j2_out]
7386
wj *= w1
7487
end
7588
end
@@ -78,45 +91,52 @@ end
7891
Power of 2 FFT in place, real
7992
8093
"""
81-
function fft_pow2!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, d::Direction) where {T<:Real}
82-
N = length(out)
94+
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T<:Real}
8395
if N == 2
8496
out[1] = in[1] + in[2]
8597
out[2] = in[1] - in[2]
8698
return
8799
end
88-
fft_pow2!(@view(out[1:(end÷2)]), @view(in[1:2:end]), d)
89-
fft_pow2!(@view(out[(end÷2+1):end]), @view(in[2:2:end]), d)
100+
m = N ÷ 2
101+
fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, d)
102+
fft_pow2!(out, in, m, start_out + m, stride_out, start_in + 1, stride_in*2, d)
90103

91104
w1 = convert(Complex{T}, cispi(direction_sign(d)*2/N))
92105
wj = one(Complex{T})
93-
m = N ÷ 2
94-
@inbounds @turbo for j in 2:m
95-
out[j] = out[j] + wj*out[j+m]
106+
@inbounds @turbo for j in 1:m-1
107+
j1_out = start_out + j*stride_out
108+
j2_out = start_out + (j+m)*stride_out
109+
out[j1_out] = out[j1_out] + wj*out[j2_out]
96110
wj *= w1
97111
end
98-
@inbounds @turbo for j in 2:m
99-
out[m+j] = conj(out[m-j+2])
112+
@inbounds @turbo for j in 1:m-1
113+
j1_out = start_out + (j+m)*stride_out
114+
j2_out = start_out + (m-j+1)*stride_out
115+
out[j1_out] = conj(out[j2_out])
100116
end
101117
end
102118

103-
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, d::Direction) where {T}
104-
N = length(out)
119+
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T}
120+
@info "" start_out stride_out start_in stride_in N
105121
wn² = wn = w = convert(T, cispi(direction_sign(d)*2/N))
106122
wn_1 = one(T)
107123

108-
tmp = in[1]
124+
tmp = in[start_in]
109125
out .= tmp
110-
tmp = sum(in)
111-
out[1] = tmp
112-
126+
tmp = sum(@view in[start_in:stride_in:start_in+stride_in*(N-1)])
127+
out[start_out] = tmp
128+
113129
wk = wn²
114-
@inbounds for d in 2:N
115-
out[d] = in[d]*wk + out[d]
116-
@inbounds for k in (d+1):N
130+
@inbounds for d in 1:N-1
131+
d_in = start_in + d*stride_in
132+
d_out = start_out + d*stride_out
133+
out[d_out] = in[d_in]*wk + out[d_out]
134+
@inbounds for k in d:N-1
135+
k_in = start_in + k*stride_in
136+
k_out = start_out + k*stride_out
117137
wk *= wn
118-
out[d] = in[k]*wk + out[d]
119-
out[k] = in[d]*wk + out[k]
138+
out[d_out] = in[k_in]*wk + out[d_out]
139+
out[k_out] = in[d_in]*wk + out[k_out]
120140
end
121141
wn_1 = wn
122142
wn *= w
@@ -125,30 +145,30 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, d::Direction) w
125145
end
126146
end
127147

128-
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, d::Direction) where {T<:Real}
129-
N = length(out)
148+
function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, d::Direction) where {T<:Real}
130149
halfN = N÷2
131150
wk = wkn = w = convert(Complex{T}, cispi(direction_sign(d)*2/N))
132151

133-
out[2:N] .= in[1]
134-
out[1] = sum(in)
135-
iseven(N) && (out[halfN+1] = alternatingSum(in))
152+
out[start_out + 1:stride_out:start_out+stride_out*N] .= in[1]
153+
out[1] = sum(@view in[start_in:stride_in:start_int+stride_out*N])
154+
iseven(N) && (out[start_out + stride_out*halfN] = alternatingSum(@view in[start_in:stride_in:start_int+stride_out*N]))
136155

137156
@inbounds for d in 2:halfN+1
138-
tmp = in[1]
157+
tmp = in[start_in]
139158
@inbounds for k in 2:N
140-
tmp += wkn*in[k]
159+
tmp += wkn*in[start_in + k*stride_in]
141160
wkn *= wk
142161
end
143-
out[d] = tmp
162+
out[start_out + d*stride_out] = tmp
144163
wk *= w
145164
wkn = wk
146165
end
147166
@inbounds @turbo for i in 0:halfN-1
148-
out[N-i] = conj(out[halfN-i])
167+
out[start_out + stride_out*(N-i)] = conj(out[start_out + stride_out*(halfN-i)])
149168
end
150169
end
151170

152-
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, d::Direction, ::DFT, ::CallGraph{T}, ::Int) where {T,U}
153-
fft_dft!(out, in, d)
171+
function fft!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, ::DFT, g::CallGraph{T}, idx::Int) where {T,U}
172+
root = g[idx]
173+
fft_dft!(out, in, root.sz, start_out, root.s_out, start_in, root.s_in, d)
154174
end

src/callgraph.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct CallGraphNode
3333
right::Int
3434
type::AbstractFFTType
3535
sz::Int
36+
s_in::Int
37+
s_out::Int
3638
end
3739

3840
"""
@@ -63,12 +65,12 @@ leftNode(g::CallGraph, i::Int) = g[i+g[i].left]
6365
rightNode(g::CallGraph, i::Int) = g[i+g[i].right]
6466

6567
# Recursively instantiate a set of `CallGraphNode`s
66-
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}})::Int where {T}
68+
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, s_in::Int, s_out::Int)::Int where {T}
6769
facs = factor(N)
6870
Ns = [first(x) for x in collect(facs) for _ in 1:last(x)]
6971
if length(Ns) == 1 || Ns[end] == 2
7072
push!(workspace, T[])
71-
push!(nodes, CallGraphNode(0,0,Ns[end] == 2 ? Pow2FFT() : DFT(),N))
73+
push!(nodes, CallGraphNode(0,0,Ns[end] == 2 ? Pow2FFT() : DFT(),N, s_in, s_out))
7274
return 1
7375
end
7476

@@ -83,19 +85,19 @@ function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{
8385
N1 = N_cp[N1_idx]
8486
end
8587
N2 = N ÷ N1
86-
push!(nodes, CallGraphNode(0,0,DFT(),N))
88+
push!(nodes, CallGraphNode(0,0,DFT(),N,s_in,s_out))
8789
sz = length(nodes)
8890
push!(workspace, Vector{T}(undef, N))
89-
left_len = CallGraphNode!(nodes, N1, workspace)
90-
right_len = CallGraphNode!(nodes, N2, workspace)
91-
nodes[sz] = CallGraphNode(1, 1 + left_len, CompositeFFT(), N)
91+
left_len = CallGraphNode!(nodes, N1, workspace, N2, N2*s_out)
92+
right_len = CallGraphNode!(nodes, N2, workspace, N1*s_in, 1)
93+
nodes[sz] = CallGraphNode(1, 1 + left_len, CompositeFFT(), N, s_in, s_out)
9294
return 1 + left_len + right_len
9395
end
9496

9597
# Instantiate a CallGraph from a number `N`
9698
function CallGraph{T}(N::Int) where {T}
9799
nodes = CallGraphNode[]
98100
workspace = Vector{Vector{T}}()
99-
CallGraphNode!(nodes, N, workspace)
101+
CallGraphNode!(nodes, N, workspace, 1, 1)
100102
CallGraph(nodes, workspace)
101103
end

test/ffta.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using FFTA
22

3-
for N in [8, 11, 15, 100]
3+
# for N in [8, 11, 15, 100]
4+
for N in [8, 11, 15]
45
x = zeros(ComplexF64, N)
56
x[1] = 1
67
y = fft(x)

0 commit comments

Comments
 (0)