Skip to content

Commit fe3b06a

Browse files
committed
Add plan_inv implementation for adjoint plan and test it
1 parent 2a2d685 commit fe3b06a

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/definitions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,4 +654,6 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
654654
return scale ./ N .* (p.p \ x)
655655
end
656656

657+
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
658+
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
657659
inv(p::AdjointPlan) = adjoint(inv(p.p))

test/runtests.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,14 @@ end
263263
@test (P')' === P # test adjoint of adjoint
264264
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
265265
@test dot(y, P * x) dot(P' * y, x) # test validity of adjoint
266-
@test dot(y, P \ x) dot(P' \ y, x)
266+
@test dot(y, P \ x) dot(P' \ y, x) # test inv of adjoint
267+
@test dot(y, P \ x) dot(AbstractFFTs.plan_inv(P') * y, x) # test plan_inv of adjoint
267268
Pinv = plan_ifft(y)
268269
@test (Pinv')' * y == Pinv * y
269270
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
270271
@test dot(x, Pinv * y) dot(Pinv' * x, y)
271272
@test dot(x, Pinv \ y) dot(Pinv' \ x, y)
273+
@test dot(x, Pinv \ y) dot(AbstractFFTs.plan_inv(Pinv') * x, y)
272274
@test_throws MethodError mul!(x, P', y)
273275
end
274276
end
@@ -281,14 +283,17 @@ end
281283
P = plan_rfft(x, dims)
282284
y = randn(ComplexF64, size(P * x))
283285
@test (P')' * x == P * x
284-
@test size(P') == AbstractFFTs.output_size(P)
286+
@test size(P') == AbstractFFTs.output_size(P)
285287
@test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) dot(P' * y, x)
286288
@test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) dot(P \ y, x)
289+
@test dot(real.(y), real.(AbstractFFTs.plan_inv(P') * x)) +
290+
dot(imag.(y), imag.(AbstractFFTs.plan_inv(P') * x)) dot(P \ y, x)
287291
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
288292
@test (Pinv')' * y == Pinv * y
289293
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
290294
@test dot(x, Pinv * y) dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x))
291295
@test dot(x, Pinv' \ y) dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x))
296+
@test dot(x, AbstractFFTs.plan_inv(Pinv') * y) dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x))
292297
end
293298
end
294299
end

0 commit comments

Comments
 (0)