Skip to content

Commit 1f752ee

Browse files
committed
Fix some performance type issue with conv
1 parent 151ab68 commit 1f752ee

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/convolutions.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,23 @@ function plan_conv(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims=ntuple
136136
# FFTW.MEASURE flag might overwrite input! Hence copy!
137137
if (:flags in keys(kwargs) &&
138138
(getindex(kwargs, :flags) == FFTW.MEASURE || getindex(kwargs, :flags) == FFTW.PATIENT))
139-
P = plan(copy(u), dims; kwargs...)
139+
plan(copy(u), dims; kwargs...)
140140
else
141-
P = plan(u, dims; kwargs...)
141+
plan(u, dims; kwargs...)
142142
end
143143
end
144-
P_inv = inv(P)
145144

146145
v_ft = fft_or_rfft(T1)(v, dims)
147146
# construct the efficient conv function
148147
# P and P_inv can be understood like matrices
149148
# but their computation is fast
150-
conv(u, v_ft=v_ft) = p_conv_aux(P, P_inv, u, v_ft)
149+
conv = let P = P,
150+
P_inv = inv(P),
151+
# put a different name here! See https://discourse.julialang.org/t/type-issue-with-captured-variables-let-workaround-failed/85661
152+
v_ft = v_ft
153+
conv(u, v_ft=v_ft) = p_conv_aux(P, P_inv, u, v_ft)
154+
end
155+
151156
return v_ft, conv
152157
end
153158

@@ -211,8 +216,7 @@ function plan_conv_psf(u::AbstractArray{T, N}, psf::AbstractArray{T, M}, dims=nt
211216
end
212217

213218
function p_conv_aux(P, P_inv, u, v_ft)
214-
tmp = (P_inv.p * ((P * u) .* v_ft .* P_inv.scale))
215-
return tmp
219+
return (P_inv.p * ((P * u) .* v_ft .* P_inv.scale))
216220
end
217221

218222
function ChainRulesCore.rrule(::typeof(p_conv_aux), P, P_inv, u, v)

0 commit comments

Comments
 (0)