Skip to content

Commit 23a4b1a

Browse files
committed
Works for vectors
1 parent f89120b commit 23a4b1a

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

src/plan.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import Base: *
2+
import LinearAlgebra: mul!
3+
14
struct FFTAInvPlan{T} <: Plan{T} end
25

36
@computed struct FFTAPlan{T<:Union{Real, Complex},N} <: Plan{T}
47
callgraph::NTuple{N, CallGraph{(T<:Real) ? Complex{T} : T}}
5-
region::NTuple{N, Int}
8+
region
69
dir::Direction
710
pinv::FFTAInvPlan{T}
811
end
@@ -12,12 +15,12 @@ function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan
1215
@assert N <= 2 "Only supports vectors and matrices"
1316
if N == 1
1417
g = CallGraph{T}(size(x,region[]))
15-
pinv = FFTAInvPlan()
18+
pinv = FFTAInvPlan{T}()
1619
return FFTAPlan{T,N}((g,), region, FFT_FORWARD, pinv)
1720
else
1821
g1 = CallGraph{T}(size(x,region[1]))
1922
g2 = CallGraph{T}(size(x,region[2]))
20-
pinv = FFTAInvPlan()
23+
pinv = FFTAInvPlan{T}()
2124
return FFTAPlan{T,N}((g1,g2), region, FFT_FORWARD, pinv)
2225
end
2326
end
@@ -26,11 +29,21 @@ function AbstractFFTs.plan_bfft(p::FFTAPlan{T,N}) where {T,N}
2629
return FFTAPlan{T,N}(p.callgraph, p.region, -p.dir, p.pinv)
2730
end
2831

29-
function LinearAlgebra.mul!(y, p::FFTAPlan, x)
30-
fft!(y, x, 1, 1, p.dir, p.callgraph[1].type, p.callgraph, 1)
32+
function LinearAlgebra.mul!(y::AbstractVector{T}, p::FFTAPlan{T,1}, x::AbstractVector{T}) where T
33+
fft!(y, x, 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
34+
end
35+
36+
function LinearAlgebra.mul!(y::AbstractArray{T,N}, p::FFTAPlan{T,1}, x::AbstractArray{T,N}) where {T,N}
37+
Rpre = CartesianIndices(size(x)[1:p.region-1])
38+
Rpost = CartesianIndices(size(x)[p.region+1:end])
39+
for Ipre in Rpre
40+
for Ipost in Rpost
41+
@views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
42+
end
43+
end
3144
end
3245

33-
function *(p::FFTAPlan, x)
46+
function *(p::FFTAPlan{T,1}, x::AbstractVector{T}) where {T<:Union{Real, Complex}}
3447
y = similar(x)
3548
LinearAlgebra.mul!(y, p, x)
3649
y

0 commit comments

Comments
 (0)