Skip to content

Commit 111eda5

Browse files
authored
Increasing robustness of adjoint plan optimizations (#123)
* Add some safeguards to prioritize type stability in adjoint plans * Test adjoint plans with float32's * Add subtype * Apply suggestions from code review * Rename dummy -> wrapper
1 parent a67bf15 commit 111eda5

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

src/definitions.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -681,13 +681,8 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
681681
dims = fftdims(p)
682682
N = normalization(T, size(p), dims)
683683
pinv = inv(p)
684-
# Optimization: when pinv is a ScaledPlan, check if we can avoid a loop over x.
685-
# Even if not, ensure that we do only one pass by combining the normalization with the plan.
686-
if pinv isa ScaledPlan && pinv.scale == N
687-
return pinv.p * x
688-
else
689-
return (1/N * pinv) * x
690-
end
684+
# Ensure that we do only one pass over the array by combining the normalization with the plan.
685+
return (inv(N) * pinv) * x
691686
end
692687

693688
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
@@ -698,7 +693,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<
698693
pinv = inv(p)
699694
n = size(pinv, halfdim)
700695
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
701-
scale = pinv isa ScaledPlan ? pinv.scale / 2N : 1 / 2N
696+
scale = pinv isa ScaledPlan ? pinv.scale / 2N : inv(2N)
702697
twoscale = 2 * scale
703698
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
704699
y = map(x, CartesianIndices(x)) do xj, j
@@ -721,7 +716,7 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T
721716
pinv = inv(p)
722717
d = size(pinv, halfdim)
723718
# Optimization: when pinv is a ScaledPlan, fuse the scaling into our map to ensure we do not loop over x twice.
724-
scale = pinv isa ScaledPlan ? pinv.scale / N : 1 / N
719+
scale = pinv isa ScaledPlan ? pinv.scale / N : inv(N)
725720
twoscale = 2 * scale
726721
unscaled_pinv = pinv isa ScaledPlan ? pinv.p : pinv
727722
y = unscaled_pinv * x

test/TestPlans.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,4 +278,20 @@ AbstractFFTs.plan_inv(p::InplaceTestPlan) = InplaceTestPlan(AbstractFFTs.plan_in
278278
# Don't cache inverse of inplace wrapper plan (only inverse of inner plan)
279279
Base.inv(p::InplaceTestPlan) = InplaceTestPlan(inv(p.plan))
280280

281+
# A wrapper plan whose inverse is not an instance of AbstractFFTs.ScaledPlan, for testing purposes
282+
283+
struct WrapperTestPlan{T,P<:Plan{T}} <: Plan{T}
284+
plan::P
285+
end
286+
287+
Base.size(p::WrapperTestPlan) = size(p.plan)
288+
Base.ndims(p::WrapperTestPlan) = ndims(p.plan)
289+
AbstractFFTs.fftdims(p::WrapperTestPlan) = fftdims(p.plan)
290+
AbstractFFTs.AdjointStyle(p::WrapperTestPlan) = AbstractFFTs.AdjointStyle(p.plan)
291+
292+
Base.:*(p::WrapperTestPlan, x::AbstractArray) = p.plan * x
293+
294+
AbstractFFTs.plan_inv(p::WrapperTestPlan) = WrapperTestPlan(AbstractFFTs.plan_inv(p.plan))
295+
Base.inv(p::WrapperTestPlan) = WrapperTestPlan(inv(p.plan))
296+
281297
end

test/runtests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,39 @@ end
145145
end
146146
end
147147

148+
@testset "Adjoint plan on single-precision" begin
149+
# fft
150+
p = plan_fft(zeros(ComplexF32, 3))
151+
u = rand(ComplexF32, 3)
152+
@test eltype(p' * (p * u)) == eltype(u)
153+
# rfft
154+
p = plan_rfft(zeros(Float32, 3))
155+
u = rand(Float32, 3)
156+
@test eltype(p' * (p * u)) == eltype(u)
157+
# brfft
158+
p = plan_brfft(zeros(ComplexF32, 3), 5)
159+
u = rand(ComplexF32, 3)
160+
@test eltype(p' * (p * u)) == eltype(u)
161+
end
162+
163+
@testset "Adjoint plan application when plan inverse is not a ScaledPlan" begin
164+
# fft
165+
p0 = plan_fft(zeros(ComplexF64, 3))
166+
p = TestPlans.WrapperTestPlan(p0)
167+
u = rand(ComplexF64, 3)
168+
@test p' * u p0' * u
169+
# rfft
170+
p0 = plan_rfft(zeros(3))
171+
p = TestPlans.WrapperTestPlan(p0)
172+
u = rand(ComplexF64, 2)
173+
@test p' * u p0' * u
174+
# brfft
175+
p0 = plan_brfft(zeros(ComplexF64, 3), 5)
176+
p = TestPlans.WrapperTestPlan(p0)
177+
u = rand(Float64, 5)
178+
@test p' * u p0' * u
179+
end
180+
148181
@testset "ChainRules" begin
149182
@testset "shift functions" begin
150183
for x in (randn(3), randn(3, 4), randn(3, 4, 5))

0 commit comments

Comments
 (0)