@@ -11,7 +11,7 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlg
1111 @assert W isa AbstractMatrix && P isa AbstractMatrix
1212 @check_size(W, (m, n))
1313 @check_scalar(W, A)
14- @check_size(P, (n, n))
14+ isempty(P) || @check_size(P, (n, n))
1515 @check_scalar(P, A)
1616 return nothing
1717end
@@ -21,7 +21,7 @@ function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::Abstrac
2121 n >= m ||
2222 throw(ArgumentError(" input matrix needs at least as many columns as rows" ))
2323 @assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
24- @check_size(P, (m, m))
24+ isempty(P) || @check_size(P, (m, m))
2525 @check_scalar(P, A)
2626 @check_size(Wᴴ, (m, n))
2727 @check_scalar(Wᴴ, A)
@@ -43,25 +43,154 @@ function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::Abstract
4343 return (P, Wᴴ)
4444end
4545
46- # Implementation
47- # --------------
46+ # Implementation via SVD
47+ # -----------------------
4848function left_polar!(A:: AbstractMatrix , WP, alg:: PolarViaSVD )
4949 check_input(left_polar!, A, WP, alg)
50- U, S, Vᴴ = svd_compact!(A, alg. svdalg )
50+ U, S, Vᴴ = svd_compact!(A, alg. svd_alg )
5151 W, P = WP
5252 W = mul!(W, U, Vᴴ)
53- S .= sqrt.(S)
54- SsqrtVᴴ = lmul!(S, Vᴴ)
55- P = mul!(P, SsqrtVᴴ' , SsqrtVᴴ)
53+ if ! isempty(P)
54+ S .= sqrt.(S)
55+ SsqrtVᴴ = lmul!(S, Vᴴ)
56+ P = mul!(P, SsqrtVᴴ' , SsqrtVᴴ)
57+ end
5658 return (W, P)
5759end
5860function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
5961 check_input(right_polar!, A, PWᴴ, alg)
60- U, S, Vᴴ = svd_compact!(A, alg.svdalg )
62+ U, S, Vᴴ = svd_compact!(A, alg.svd_alg )
6163 P, Wᴴ = PWᴴ
6264 Wᴴ = mul!(Wᴴ, U, Vᴴ)
63- S .= sqrt.(S)
64- USsqrt = rmul!(U, S)
65- P = mul!(P, USsqrt, USsqrt' )
65+ if !isempty(P)
66+ S .= sqrt.(S)
67+ USsqrt = rmul!(U, S)
68+ P = mul!(P, USsqrt, USsqrt' )
69+ end
6670 return (P, Wᴴ)
6771end
72+
73+ # Implementation via Newton
74+ # --------------------------
75+ function left_polar!(A:: AbstractMatrix , WP, alg:: PolarNewton )
76+ check_input(left_polar!, A, WP, alg)
77+ W, P = WP
78+ if isempty(P)
79+ W = _left_polarnewton!(A, W, P; alg. kwargs... )
80+ return W, P
81+ else
82+ W = _left_polarnewton!(copy(A), W, P; alg. kwargs... )
83+ # we still need `A` to compute `P`
84+ P = project_hermitian!(mul!(P, W' , A))
85+ return W, P
86+ end
87+ end
88+
89+ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarNewton)
90+ check_input(right_polar!, A, PWᴴ, alg)
91+ P, Wᴴ = PWᴴ
92+ if isempty(P)
93+ Wᴴ = _right_polarnewton!(A, Wᴴ, P; alg.kwargs...)
94+ return P, Wᴴ
95+ else
96+ Wᴴ = _right_polarnewton!(copy(A), Wᴴ, P; alg.kwargs...)
97+ # we still need `A` to compute `P`
98+ P = project_hermitian!(mul!(P, A, Wᴴ' ))
99+ return P, Wᴴ
100+ end
101+ end
102+
103+ # these methods only compute W and destroy A in the process
104+ function _left_polarnewton!(A:: AbstractMatrix , W, P = similar(A, (0 , 0 )); tol = defaulttol(A), maxiter = 10 )
105+ m, n = size(A) # we must have m >= n
106+ Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
107+ if m > n # initial QR
108+ Q, R = qr_compact!(A)
109+ Rc = view(A, 1 : n, 1 : n)
110+ copy!(Rc, R)
111+ Rᴴinv = ldiv!(UpperTriangular(Rc)' , one!(Rᴴinv))
112+ else # m == n
113+ R = A
114+ Rc = view(W, 1 : n, 1 : n)
115+ copy!(Rc, R)
116+ Rᴴinv = ldiv!(lu!(Rc)' , one!(Rᴴinv))
117+ end
118+ γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
119+ rmul!(R, γ)
120+ rmul!(Rᴴinv, 1 / γ)
121+ R, Rᴴinv = _avgdiff!(R, Rᴴinv)
122+ copy!(Rc, R)
123+ i = 1
124+ conv = norm(Rᴴinv, Inf )
125+ while i < maxiter && conv > tol
126+ Rᴴinv = ldiv!(lu!(Rc)' , one!(Rᴴinv))
127+ γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
128+ rmul!(R, γ)
129+ rmul!(Rᴴinv, 1 / γ)
130+ R, Rᴴinv = _avgdiff!(R, Rᴴinv)
131+ copy!(Rc, R)
132+ conv = norm(Rᴴinv, Inf )
133+ i += 1
134+ end
135+ if conv > tol
136+ @warn " `left_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv )"
137+ end
138+ if m > n
139+ return mul!(W, Q, Rc)
140+ end
141+ return W
142+ end
143+
144+ function _right_polarnewton!(A:: AbstractMatrix , Wᴴ, P = similar(A, (0 , 0 )); tol = defaulttol(A), maxiter = 10 )
145+ m, n = size(A) # we must have m <= n
146+ Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
147+ if m < n # initial QR
148+ L, Q = lq_compact!(A)
149+ Lc = view(A, 1 : m, 1 : m)
150+ copy!(Lc, L)
151+ Lᴴinv = ldiv!(LowerTriangular(Lc)' , one!(Lᴴinv))
152+ else # m == n
153+ L = A
154+ Lc = view(Wᴴ, 1 : m, 1 : m)
155+ copy!(Lc, L)
156+ Lᴴinv = ldiv!(lu!(Lc)' , one!(Lᴴinv))
157+ end
158+ γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
159+ rmul!(L, γ)
160+ rmul!(Lᴴinv, 1 / γ)
161+ L, Lᴴinv = _avgdiff!(L, Lᴴinv)
162+ copy!(Lc, L)
163+ i = 1
164+ conv = norm(Lᴴinv, Inf )
165+ while i < maxiter && conv > tol
166+ Lᴴinv = ldiv!(lu!(Lc)' , one!(Lᴴinv))
167+ γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
168+ rmul!(L, γ)
169+ rmul!(Lᴴinv, 1 / γ)
170+ L, Lᴴinv = _avgdiff!(L, Lᴴinv)
171+ copy!(Lc, L)
172+ conv = norm(Lᴴinv, Inf )
173+ i += 1
174+ end
175+ if conv > tol
176+ @warn " `right_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv )"
177+ end
178+ if m < n
179+ return mul!(Wᴴ, Lc, Q)
180+ end
181+ return Wᴴ
182+ end
183+
184+ # in place computation of the average and difference of two arrays
185+ function _avgdiff!(A:: AbstractArray , B:: AbstractArray )
186+ axes(A) == axes(B) || throw(DimensionMismatch())
187+ @simd for I in eachindex(A, B)
188+ @inbounds begin
189+ a = A[I]
190+ b = B[I]
191+ A[I] = (a + b) / 2
192+ B[I] = b - a
193+ end
194+ end
195+ return A, B
196+ end
0 commit comments