@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
12
12
13
13
# size(p) should return the size of the input array for p
14
14
size (p:: Plan , d) = size (p)[d]
15
+ output_size (p:: Plan , d) = output_size (p)[d]
15
16
ndims (p:: Plan ) = length (size (p))
16
17
length (p:: Plan ) = prod (size (p)):: Int
17
18
@@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
255
256
ScaledPlan (p:: ScaledPlan , α:: Number ) = ScaledPlan (p. p, p. scale * α)
256
257
257
258
size (p:: ScaledPlan ) = size (p. p)
259
+ output_size (p:: ScaledPlan ) = output_size (p. p)
258
260
259
261
fftdims (p:: ScaledPlan ) = fftdims (p. p)
260
262
@@ -576,3 +578,67 @@ Pre-plan an optimized real-input unnormalized transform, similar to
576
578
the same as for [`brfft`](@ref).
577
579
"""
578
580
plan_brfft
581
+
582
+ # #############################################################################
583
+
584
+ struct NoProjectionStyle end
585
+ struct RealProjectionStyle end
586
+ struct RealInverseProjectionStyle end
587
+ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
588
+
589
+ function irfft_dim end
590
+
591
+ output_size (p:: Plan ) = _output_size (p, ProjectionStyle (p))
592
+ _output_size (p:: Plan , :: NoProjectionStyle ) = size (p)
593
+ _output_size (p:: Plan , :: RealProjectionStyle ) = rfft_output_size (size (p), region (p))
594
+ _output_size (p:: Plan , :: RealInverseProjectionStyle ) = brfft_output_size (size (p), irfft_dim (p), region (p))
595
+
596
+ mutable struct AdjointPlan{T,P} <: Plan{T}
597
+ p:: P
598
+ pinv:: Plan
599
+ AdjointPlan {T,P} (p) where {T,P} = new (p)
600
+ end
601
+
602
+ Base. adjoint (p:: Plan{T} ) where {T} = AdjointPlan {T, typeof(p)} (p)
603
+ Base. adjoint (p:: AdjointPlan{T} ) where {T} = p. p
604
+ # always have AdjointPlan inside ScaledPlan.
605
+ Base. adjoint (p:: ScaledPlan{T} ) where {T} = ScaledPlan {T} (p. p' , p. scale)
606
+
607
+ size (p:: AdjointPlan ) = output_size (p. p)
608
+ output_size (p:: AdjointPlan ) = size (p. p)
609
+
610
+ Base.:* (p:: AdjointPlan , x:: AbstractArray ) = _mul (p, x, ProjectionStyle (p. p))
611
+
612
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: NoProjectionStyle ) where {T}
613
+ dims = region (p. p)
614
+ N = normalization (T, size (p. p), dims)
615
+ return 1 / N * (p. p \ x)
616
+ end
617
+
618
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealProjectionStyle ) where {T}
619
+ dims = region (p. p)
620
+ N = normalization (T, size (p. p), dims)
621
+ halfdim = first (dims)
622
+ d = size (p. p, halfdim)
623
+ n = output_size (p. p, halfdim)
624
+ scale = reshape (
625
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
626
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
627
+ )
628
+ return 1 / N * (p. p \ (x ./ scale))
629
+ end
630
+
631
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealInverseProjectionStyle ) where {T}
632
+ dims = region (p. p)
633
+ N = normalization (real (T), output_size (p. p), dims)
634
+ halfdim = first (dims)
635
+ n = size (p. p, halfdim)
636
+ d = output_size (p. p, halfdim)
637
+ scale = reshape (
638
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
639
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
640
+ )
641
+ return 1 / N * scale .* (p. p \ x)
642
+ end
643
+
644
+ plan_inv (p:: AdjointPlan ) = AdjointPlan (plan_inv (p. p))
0 commit comments