Skip to content

Commit e6bd067

Browse files
committed
Move callgraph out
1 parent d3a1f8b commit e6bd067

File tree

3 files changed

+154
-154
lines changed

3 files changed

+154
-154
lines changed

src/FFTA.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,62 @@ using Primes, DocStringExtensions, LoopVectorization
44
import Base: getindex
55
export fft, bfft
66

7+
include("callgraph.jl")
78
include("algos.jl")
89

10+
function fft(x::AbstractVector{T}) where {T}
11+
y = similar(x)
12+
g = CallGraph{T}(length(x))
13+
fft!(y, x, Val(FFT_FORWARD), g[1].type, g, 1)
14+
y
15+
end
16+
17+
function fft(x::AbstractVector{T}) where {T <: Real}
18+
y = similar(x, Complex{T})
19+
g = CallGraph{Complex{T}}(length(x))
20+
fft!(y, x, Val(FFT_FORWARD), g[1].type, g, 1)
21+
y
22+
end
23+
24+
function fft(x::AbstractMatrix{T}) where {T}
25+
M,N = size(x)
26+
y1 = similar(x)
27+
y2 = similar(x)
28+
g1 = CallGraph{T}(size(x,1))
29+
g2 = CallGraph{T}(size(x,2))
30+
31+
for k in 1:N
32+
@views fft!(y1[:,k], x[:,k], Val(FFT_FORWARD), g1[1].type, g1, 1)
33+
end
34+
35+
for k in 1:M
36+
@views fft!(y2[k,:], y1[k,:], Val(FFT_FORWARD), g2[1].type, g2, 1)
37+
end
38+
y2
39+
end
40+
41+
function bfft(x::AbstractVector{T}) where {T}
42+
y = similar(x)
43+
g = CallGraph{T}(length(x))
44+
fft!(y, x, Val(FFT_BACKWARD), g[1].type, g, 1)
45+
y
46+
end
47+
48+
function bfft(x::AbstractMatrix{T}) where {T}
49+
M,N = size(x)
50+
y1 = similar(x)
51+
y2 = similar(x)
52+
g1 = CallGraph{T}(size(x,1))
53+
g2 = CallGraph{T}(size(x,2))
54+
55+
for k in 1:N
56+
@views fft!(y1[:,k], x[:,k], Val(FFT_BACKWARD), g1[1].type, g1, 1)
57+
end
58+
59+
for k in 1:M
60+
@views fft!(y2[k,:], y1[k,:], Val(FFT_BACKWARD), g2[1].type, g2, 1)
61+
end
62+
y2
63+
end
64+
965
end

src/algos.jl

Lines changed: 0 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
@enum Direction FFT_FORWARD FFT_BACKWARD
2-
abstract type AbstractFFTType end
3-
41
function alternatingSum(x::AbstractVector{T}) where T
52
y = x[1]
63
@turbo for i in 2:length(x)
@@ -9,157 +6,6 @@ function alternatingSum(x::AbstractVector{T}) where T
96
y
107
end
118

12-
# Represents a Composite Cooley-Tukey FFT
13-
struct CompositeFFT <: AbstractFFTType end
14-
15-
# Represents a Radix-2 Cooley-Tukey FFT
16-
struct Pow2FFT <: AbstractFFTType end
17-
18-
# Represents an O(N²) DFT
19-
struct DFT <: AbstractFFTType end
20-
21-
"""
22-
$(TYPEDSIGNATURES)
23-
Node of a call graph
24-
25-
# Arguments
26-
`left::Int`- Offset to the left child node
27-
`right::Int`- Offset to the right child node
28-
`type::AbstractFFTType`- Object representing the type of FFT
29-
`sz::Int`- Size of this FFT
30-
31-
# Examples
32-
```julia
33-
julia> CallGraphNode(0, 0, Pow2FFT(), 8)
34-
```
35-
"""
36-
struct CallGraphNode
37-
left::Int
38-
right::Int
39-
type::AbstractFFTType
40-
sz::Int
41-
end
42-
43-
"""
44-
$(TYPEDSIGNATURES)
45-
Object representing a graph of FFT Calls
46-
47-
# Arguments
48-
`nodes::Vector{CallGraphNode}`- Nodes keeping track of the graph
49-
`workspace::Vector{Vector{T}}`- Preallocated Workspace
50-
51-
# Examples
52-
```julia
53-
julia> CallGraph{ComplexF64}(CallGraphNode[], Vector{T}[])
54-
```
55-
"""
56-
struct CallGraph{T<:Complex}
57-
nodes::Vector{CallGraphNode}
58-
workspace::Vector{Vector{T}}
59-
end
60-
61-
# Get the node in the graph at index i
62-
Base.getindex(g::CallGraph{T}, i::Int) where {T} = g.nodes[i]
63-
64-
# Get the left child of the node at index `i`
65-
leftNode(g::CallGraph, i::Int) = g[i+g[i].left]
66-
67-
# Get the right child of the node at index `i`
68-
rightNode(g::CallGraph, i::Int) = g[i+g[i].right]
69-
70-
# Recursively instantiate a set of `CallGraphNode`s
71-
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}})::Int where {T}
72-
facs = factor(N)
73-
Ns = [first(x) for x in collect(facs) for _ in 1:last(x)]
74-
if length(Ns) == 1 || Ns[end] == 2
75-
push!(workspace, T[])
76-
push!(nodes, CallGraphNode(0,0,Ns[end] == 2 ? Pow2FFT() : DFT(),N))
77-
return 1
78-
end
79-
80-
if Ns[1] == 2
81-
N1 = prod(Ns[Ns .== 2])
82-
else
83-
# Greedy search for closest factor of N to sqrt(N)
84-
Nsqrt = sqrt(N)
85-
N_cp = cumprod(Ns[end:-1:1])[end:-1:1]
86-
N_prox = abs.(N_cp .- Nsqrt)
87-
_,N1_idx = findmin(N_prox)
88-
N1 = N_cp[N1_idx]
89-
end
90-
N2 = N ÷ N1
91-
push!(nodes, CallGraphNode(0,0,DFT(),N))
92-
sz = length(nodes)
93-
push!(workspace, Vector{T}(undef, N))
94-
left_len = CallGraphNode!(nodes, N1, workspace)
95-
right_len = CallGraphNode!(nodes, N2, workspace)
96-
nodes[sz] = CallGraphNode(1, 1 + left_len, CompositeFFT(), N)
97-
return 1 + left_len + right_len
98-
end
99-
100-
# Instantiate a CallGraph from a number `N`
101-
function CallGraph{T}(N::Int) where {T}
102-
nodes = CallGraphNode[]
103-
workspace = Vector{Vector{T}}()
104-
CallGraphNode!(nodes, N, workspace)
105-
CallGraph(nodes, workspace)
106-
end
107-
108-
function fft(x::AbstractVector{T}) where {T}
109-
y = similar(x)
110-
g = CallGraph{T}(length(x))
111-
fft!(y, x, Val(FFT_FORWARD), g[1].type, g, 1)
112-
y
113-
end
114-
115-
function fft(x::AbstractVector{T}) where {T <: Real}
116-
y = similar(x, Complex{T})
117-
g = CallGraph{Complex{T}}(length(x))
118-
fft!(y, x, Val(FFT_FORWARD), g[1].type, g, 1)
119-
y
120-
end
121-
122-
function fft(x::AbstractMatrix{T}) where {T}
123-
M,N = size(x)
124-
y1 = similar(x)
125-
y2 = similar(x)
126-
g1 = CallGraph{T}(size(x,1))
127-
g2 = CallGraph{T}(size(x,2))
128-
129-
for k in 1:N
130-
@views fft!(y1[:,k], x[:,k], Val(FFT_FORWARD), g1[1].type, g1, 1)
131-
end
132-
133-
for k in 1:M
134-
@views fft!(y2[k,:], y1[k,:], Val(FFT_FORWARD), g2[1].type, g2, 1)
135-
end
136-
y2
137-
end
138-
139-
function bfft(x::AbstractVector{T}) where {T}
140-
y = similar(x)
141-
g = CallGraph{T}(length(x))
142-
fft!(y, x, Val(FFT_BACKWARD), g[1].type, g, 1)
143-
y
144-
end
145-
146-
function bfft(x::AbstractMatrix{T}) where {T}
147-
M,N = size(x)
148-
y1 = similar(x)
149-
y2 = similar(x)
150-
g1 = CallGraph{T}(size(x,1))
151-
g2 = CallGraph{T}(size(x,2))
152-
153-
for k in 1:N
154-
@views fft!(y1[:,k], x[:,k], Val(FFT_BACKWARD), g1[1].type, g1, 1)
155-
end
156-
157-
for k in 1:M
158-
@views fft!(y2[k,:], y1[k,:], Val(FFT_BACKWARD), g2[1].type, g2, 1)
159-
end
160-
y2
161-
end
162-
1639
fft!(out::AbstractVector{T}, in::AbstractVector{T}, ::Val{<:Direction}, ::AbstractFFTType, ::CallGraph{T}, ::Int) where {T} = nothing
16410

16511
function (g::CallGraph{T})(out::AbstractVector{T}, in::AbstractVector{U}, v::Val{FFT_FORWARD}, t::AbstractFFTType, idx::Int) where {T,U}

src/callgraph.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
@enum Direction FFT_FORWARD FFT_BACKWARD
2+
abstract type AbstractFFTType end
3+
4+
# Represents a Composite Cooley-Tukey FFT
5+
struct CompositeFFT <: AbstractFFTType end
6+
7+
# Represents a Radix-2 Cooley-Tukey FFT
8+
struct Pow2FFT <: AbstractFFTType end
9+
10+
# Represents an O(N²) DFT
11+
struct DFT <: AbstractFFTType end
12+
13+
"""
14+
$(TYPEDSIGNATURES)
15+
Node of a call graph
16+
17+
# Arguments
18+
`left::Int`- Offset to the left child node
19+
`right::Int`- Offset to the right child node
20+
`type::AbstractFFTType`- Object representing the type of FFT
21+
`sz::Int`- Size of this FFT
22+
23+
# Examples
24+
```julia
25+
julia> CallGraphNode(0, 0, Pow2FFT(), 8)
26+
```
27+
"""
28+
struct CallGraphNode
29+
left::Int
30+
right::Int
31+
type::AbstractFFTType
32+
sz::Int
33+
end
34+
35+
"""
36+
$(TYPEDSIGNATURES)
37+
Object representing a graph of FFT Calls
38+
39+
# Arguments
40+
`nodes::Vector{CallGraphNode}`- Nodes keeping track of the graph
41+
`workspace::Vector{Vector{T}}`- Preallocated Workspace
42+
43+
# Examples
44+
```julia
45+
julia> CallGraph{ComplexF64}(CallGraphNode[], Vector{T}[])
46+
```
47+
"""
48+
struct CallGraph{T<:Complex}
49+
nodes::Vector{CallGraphNode}
50+
workspace::Vector{Vector{T}}
51+
end
52+
53+
# Get the node in the graph at index i
54+
Base.getindex(g::CallGraph{T}, i::Int) where {T} = g.nodes[i]
55+
56+
# Get the left child of the node at index `i`
57+
leftNode(g::CallGraph, i::Int) = g[i+g[i].left]
58+
59+
# Get the right child of the node at index `i`
60+
rightNode(g::CallGraph, i::Int) = g[i+g[i].right]
61+
62+
# Recursively instantiate a set of `CallGraphNode`s
63+
function CallGraphNode!(nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}})::Int where {T}
64+
facs = factor(N)
65+
Ns = [first(x) for x in collect(facs) for _ in 1:last(x)]
66+
if length(Ns) == 1 || Ns[end] == 2
67+
push!(workspace, T[])
68+
push!(nodes, CallGraphNode(0,0,Ns[end] == 2 ? Pow2FFT() : DFT(),N))
69+
return 1
70+
end
71+
72+
if Ns[1] == 2
73+
N1 = prod(Ns[Ns .== 2])
74+
else
75+
# Greedy search for closest factor of N to sqrt(N)
76+
Nsqrt = sqrt(N)
77+
N_cp = cumprod(Ns[end:-1:1])[end:-1:1]
78+
N_prox = abs.(N_cp .- Nsqrt)
79+
_,N1_idx = findmin(N_prox)
80+
N1 = N_cp[N1_idx]
81+
end
82+
N2 = N ÷ N1
83+
push!(nodes, CallGraphNode(0,0,DFT(),N))
84+
sz = length(nodes)
85+
push!(workspace, Vector{T}(undef, N))
86+
left_len = CallGraphNode!(nodes, N1, workspace)
87+
right_len = CallGraphNode!(nodes, N2, workspace)
88+
nodes[sz] = CallGraphNode(1, 1 + left_len, CompositeFFT(), N)
89+
return 1 + left_len + right_len
90+
end
91+
92+
# Instantiate a CallGraph from a number `N`
93+
function CallGraph{T}(N::Int) where {T}
94+
nodes = CallGraphNode[]
95+
workspace = Vector{Vector{T}}()
96+
CallGraphNode!(nodes, N, workspace)
97+
CallGraph(nodes, workspace)
98+
end

0 commit comments

Comments
 (0)