Skip to content

Commit bd5f7be

Browse files
committed
Remove unnecessary function call, adjust rfft's
1 parent 32a94de commit bd5f7be

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

src/plan.jl

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ struct FFTAPlan_cx{T,N} <: FFTAPlan{T,N}
1313
end
1414

1515
struct FFTAPlan_re{T,N} <: FFTAPlan{T,N}
16-
callgraph::NTuple{N, CallGraph{Complex{T}}}
16+
callgraph::NTuple{N, CallGraph{T}}
1717
region::Union{Int,AbstractVector{<:Int}}
1818
dir::Direction
1919
pinv::FFTAInvPlan{T}
20+
flen::Int
2021
end
2122

2223
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex}
@@ -57,29 +58,30 @@ function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPla
5758
if N == 1
5859
g = CallGraph{Complex{T}}(size(x,region[]))
5960
pinv = FFTAInvPlan{T,N}()
60-
return FFTAPlan_re{T,N}((g,), region, FFT_FORWARD, pinv)
61+
return FFTAPlan_re{T,N}((g,), region, FFT_FORWARD, pinv, size(x,region[]))
6162
else
6263
sort!(region)
6364
g1 = CallGraph{Complex{T}}(size(x,region[1]))
6465
g2 = CallGraph{Complex{T}}(size(x,region[2]))
6566
pinv = FFTAInvPlan{T,N}()
66-
return FFTAPlan_re{T,N}((g1,g2), region, FFT_FORWARD, pinv)
67+
return FFTAPlan_re{T,N}((g1,g2), region, FFT_FORWARD, pinv, size(x,region[1]))
6768
end
6869
end
6970

70-
function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_cx{T} where {T}
71+
function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T}
7172
N = length(region)
73+
@info "" x
7274
@assert N <= 2 "Only supports vectors and matrices"
7375
if N == 1
74-
g = CallGraph{Complex{T}}(size(x,region[]))
76+
g = CallGraph{Complex{T}}(len)
7577
pinv = FFTAInvPlan{T,N}()
76-
return FFTAPlan_cx{T,N}((g,), region, FFT_BACKWARD, pinv)
78+
return FFTAPlan_re{T,N}((g,), region, FFT_BACKWARD, pinv, len)
7779
else
7880
sort!(region)
79-
g1 = CallGraph{Complex{T}}(size(x,region[1]))
81+
g1 = CallGraph{Complex{T}}(len)
8082
g2 = CallGraph{Complex{T}}(size(x,region[2]))
8183
pinv = FFTAInvPlan{T,N}()
82-
return FFTAPlan_cx{T,N}((g1,g2), region, FFT_BACKWARD, pinv)
84+
return FFTAPlan_re{T,N}((g1,g2), region, FFT_BACKWARD, pinv, len)
8385
end
8486
end
8587

@@ -88,7 +90,7 @@ function AbstractFFTs.plan_bfft(p::FFTAPlan_cx{T,N}) where {T,N}
8890
end
8991

9092
function AbstractFFTs.plan_brfft(p::FFTAPlan_re{T,N}) where {T,N}
91-
return FFTAPlan_cx{T,N}(p.callgraph, p.region, -p.dir, p.pinv)
93+
return FFTAPlan_re{T,N}(p.callgraph, p.region, -p.dir, p.pinv, p.flen)
9294
end
9395

9496
function LinearAlgebra.mul!(y::AbstractVector{U}, p::FFTAPlan{T,1}, x::AbstractVector{T}) where {T,U}
@@ -105,19 +107,29 @@ function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan{T,1}, x::Abstract
105107
end
106108
end
107109

108-
function LinearAlgebra.mul!(y::AbstractMatrix{U}, p::FFTAPlan{T,1}, x::AbstractMatrix{T}) where {T,U}
109-
rows,cols = size(x)[p.region]
110-
y_tmp = similar(y)
111-
for k in 1:cols
112-
@views fft!(y_tmp[:,k], x[:,k], 1, 1, p.dir, p.callgraph[2][1].type, p.callgraph[2], 1)
113-
end
114-
115-
for k in 1:rows
116-
@views fft!(y[k,:], y_tmp[k,:], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
110+
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,2}, x::AbstractArray{T,N}) where {T,U,N}
111+
R1 = CartesianIndices(size(x)[1:p.region[1]-1])
112+
R2 = CartesianIndices(size(x)[p.region[1]+1:p.region[2]-1])
113+
R3 = CartesianIndices(size(x)[p.region[2]+1:end])
114+
y_tmp = similar(y, axes(y)[p.region])
115+
for I1 in R1
116+
for I2 in R2
117+
for I3 in R3
118+
rows,cols = size(x)[p.region]
119+
for k in 1:cols
120+
@views fft!(y_tmp[:,k], x[I1,:,I2,k,I3], 1, 1, p.dir, p.callgraph[2][1].type, p.callgraph[2], 1)
121+
end
122+
123+
for k in 1:rows
124+
@views fft!(y[I1,k,I2,:,I3], y_tmp[k,:], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
125+
end
126+
end
127+
end
117128
end
118129
end
119130

120-
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan{T,2}, x::AbstractArray{T,N}) where {T,U,N}
131+
132+
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_re{T,2}, x::AbstractArray{T,N}) where {T,U,N}
121133
R1 = CartesianIndices(size(x)[1:p.region[1]-1])
122134
R2 = CartesianIndices(size(x)[p.region[1]+1:p.region[2]-1])
123135
R3 = CartesianIndices(size(x)[p.region[2]+1:end])
@@ -153,10 +165,10 @@ end
153165
function *(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Union{Real, Complex}}
154166
y = similar(x, T <: Real ? Complex{T} : T)
155167
LinearAlgebra.mul!(y, p, x)
156-
y
168+
y[1:end÷2 + 1]
157169
end
158170

159-
function *(p::FFTAPlan_re{T,N1}, x::AbstractArray{T,N2}) where {T<:Union{Real, Complex}, N1, N2}
171+
function *(p::FFTAPlan_re{T,N}, x::AbstractArray{T,2}) where {T<:Union{Real, Complex}, N}
160172
y = similar(x, T <: Real ? Complex{T} : T)
161173
LinearAlgebra.mul!(y, p, x)
162174
y

0 commit comments

Comments
 (0)