Skip to content

Commit ad71816

Browse files
committed
Implement AdjointPlans
1 parent 10e12af commit ad71816

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

src/definitions.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
1212

1313
# size(p) should return the size of the input array for p
1414
size(p::Plan, d) = size(p)[d]
15+
output_size(p::Plan, d) = output_size(p)[d]
1516
ndims(p::Plan) = length(size(p))
1617
length(p::Plan) = prod(size(p))::Int
1718

@@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
255256
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
256257

257258
size(p::ScaledPlan) = size(p.p)
259+
output_size(p::ScaledPlan) = output_size(p.p)
258260

259261
fftdims(p::ScaledPlan) = fftdims(p.p)
260262

@@ -576,3 +578,67 @@ Pre-plan an optimized real-input unnormalized transform, similar to
576578
the same as for [`brfft`](@ref).
577579
"""
578580
plan_brfft
581+
582+
##############################################################################
583+
584+
struct NoProjectionStyle end
585+
struct RealProjectionStyle end
586+
struct RealInverseProjectionStyle end
587+
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
588+
589+
function irfft_dim end
590+
591+
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
592+
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
593+
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))
594+
_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p))
595+
596+
mutable struct AdjointPlan{T,P} <: Plan{T}
597+
p::P
598+
pinv::Plan
599+
AdjointPlan{T,P}(p) where {T,P} = new(p)
600+
end
601+
602+
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
603+
Base.adjoint(p::AdjointPlan{T}) where {T} = p.p
604+
# always have AdjointPlan inside ScaledPlan.
605+
Base.adjoint(p::ScaledPlan{T}) where {T} = ScaledPlan{T}(p.p', p.scale)
606+
607+
size(p::AdjointPlan) = output_size(p.p)
608+
output_size(p::AdjointPlan) = size(p.p)
609+
610+
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
611+
612+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
613+
dims = region(p.p)
614+
N = normalization(T, size(p.p), dims)
615+
return 1/N * (p.p \ x)
616+
end
617+
618+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T}
619+
dims = region(p.p)
620+
N = normalization(T, size(p.p), dims)
621+
halfdim = first(dims)
622+
d = size(p.p, halfdim)
623+
n = output_size(p.p, halfdim)
624+
scale = reshape(
625+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
626+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
627+
)
628+
return 1/N * (p.p \ (x ./ scale))
629+
end
630+
631+
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
632+
dims = region(p.p)
633+
N = normalization(real(T), output_size(p.p), dims)
634+
halfdim = first(dims)
635+
n = size(p.p, halfdim)
636+
d = output_size(p.p, halfdim)
637+
scale = reshape(
638+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
639+
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
640+
)
641+
return 1/N * scale .* (p.p \ x)
642+
end
643+
644+
plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p))

0 commit comments

Comments
 (0)