Skip to content

Commit 497ff4d

Browse files
Apply suggestions from adjoint plan code review
Co-authored-by: David Widmann <[email protected]>
1 parent 061eef9 commit 497ff4d

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/definitions.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -593,16 +593,16 @@ _output_size(p::Plan, ::NoProjectionStyle) = size(p)
593593
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))
594594
_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p))
595595

596-
mutable struct AdjointPlan{T,P} <: Plan{T}
596+
mutable struct AdjointPlan{T,P<:Plan} <: Plan{T}
597597
p::P
598598
pinv::Plan
599599
AdjointPlan{T,P}(p) where {T,P} = new(p)
600600
end
601601

602602
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
603-
Base.adjoint(p::AdjointPlan{T}) where {T} = p.p
603+
Base.adjoint(p::AdjointPlan) = p.p
604604
# always have AdjointPlan inside ScaledPlan.
605-
Base.adjoint(p::ScaledPlan{T}) where {T} = ScaledPlan{T}(p.p', p.scale)
605+
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
606606

607607
size(p::AdjointPlan) = output_size(p.p)
608608
output_size(p::AdjointPlan) = size(p.p)
@@ -612,7 +612,7 @@ Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
612612
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
613613
dims = region(p.p)
614614
N = normalization(T, size(p.p), dims)
615-
return 1/N * (p.p \ x)
615+
return (p.p \ x) / N
616616
end
617617

618618
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T}
@@ -622,10 +622,10 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where
622622
d = size(p.p, halfdim)
623623
n = output_size(p.p, halfdim)
624624
scale = reshape(
625-
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
626-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
625+
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
626+
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
627627
)
628-
return 1/N * (p.p \ (x ./ scale))
628+
return p.p \ (x ./ scale)
629629
end
630630

631631
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
@@ -636,9 +636,9 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
636636
d = output_size(p.p, halfdim)
637637
scale = reshape(
638638
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
639-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x)))
639+
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
640640
)
641-
return 1/N * scale .* (p.p \ x)
641+
return scale ./ N .* (p.p \ x)
642642
end
643643

644-
plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p))
644+
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))

0 commit comments

Comments
 (0)