@@ -70,37 +70,43 @@ function _subtractone!(a::AbstractMatrix)
7070 view (a, diagind (a)) .= view (a, diagind (a)) .- 1
7171 return a
7272end
73- function _polarsdd! (A :: StridedMatrix )
74- U, S, V = svd! (A; alg = LinearAlgebra . DivideAndConquer ())
75- return mul! (A, U, V ' )
76- end
77- function _polarsvd! (A :: StridedMatrix )
78- U, S, V = svd ! (A; alg= LinearAlgebra . QRIteration () )
79- return mul! (A, U, V ' )
73+
74+ # TODO : _left_polar! is more or less the same as MAK.left_polar! but doesn't compute the P
75+ # which is not needed here. Can we unify this?
76+ function _left_polar! (A :: StridedMatrix ,
77+ alg :: PolarViaSVD = PolarViaSVD ( LAPACK_DivideAndConquer ()) )
78+ U, _, Vᴴ = svd_compact ! (A, alg. svdalg )
79+ return mul! (A, U, Vᴴ )
8080end
81+
82+ # TODO : can we move this to a dedicated MAK algorithm?
83+ MatrixAlgebraKit. @algdef PolarNewton
84+
85+ _left_polar! (A:: StridedMatrix , alg:: PolarNewton ) = _polarnewton! (A; alg. kwargs... )
8186function _polarnewton! (A:: StridedMatrix ; tol= 10 * scalareps (A), maxiter= 5 )
8287 m, n = size (A)
8388 @assert m >= n
8489 A2 = copy (A)
85- Q, R = qr! (A2)
86- Ri = ldiv! (UpperTriangular (R)' , TensorKit . MatrixAlgebra . one! (similar (R)))
90+ Q, R = LinearAlgebra . qr! (A2)
91+ Ri = ldiv! (UpperTriangular (R)' , MatrixAlgebraKit . one! (similar (R)))
8792 R, Ri = _avgdiff! (R, Ri)
8893 i = 1
8994 R2 = view (A, 1 : n, 1 : n)
9095 fill! (view (A, (n + 1 ): m, 1 : n), zero (eltype (A)))
9196 copyto! (R2, R)
9297 while maximum (abs, Ri) > tol
9398 if i == maxiter # if not converged by now, fall back to sdd
94- _polarsdd ! (Ri)
99+ _left_polar ! (Ri)
95100 break
96101 end
97- Ri = ldiv! (lu! (R2)' , TensorKit . MatrixAlgebra . one! (Ri))
102+ Ri = ldiv! (lu! (R2)' , MatrixAlgebraKit . one! (Ri))
98103 R, Ri = _avgdiff! (R, Ri)
99104 copyto! (R2, R)
100105 i += 1
101106 end
102107 return lmul! (Q, A)
103108end
109+
104110# in place computation of the average and difference of two arrays
105111function _avgdiff! (A:: AbstractArray , B:: AbstractArray )
106112 axes (A) == axes (B) || throw (DimensionMismatch ())
124130function _stiefelexp (W:: StridedMatrix , A:: StridedMatrix , Z:: StridedMatrix , α)
125131 n, p = size (W)
126132 r = min (2 * p, n)
127- QQ, _ = qr! ([W Z])
133+ QQ, _ = LinearAlgebra . qr! ([W Z])
128134 Q = similar (W, n, r - p)
129135 @inbounds for j in Base. OneTo (r - p)
130136 for i in Base. OneTo (n)
@@ -139,7 +145,7 @@ function _stiefelexp(W::StridedMatrix, A::StridedMatrix, Z::StridedMatrix, α)
139145 A2[1 : p, (p + 1 ): end ] .= (- α) .* (R' )
140146 A2[(p + 1 ): end , (p + 1 ): end ] .= 0
141147 U = [W Q] * exp (A2)
142- U = _polarnewton ! (U)
148+ U = _left_polar ! (U, PolarNewton () )
143149 W′ = U[:, 1 : p]
144150 Q′ = U[:, (p + 1 ): end ]
145151 R′ = R
@@ -152,7 +158,7 @@ function _stiefellog(Wold::StridedMatrix, Wnew::StridedMatrix;
152158 r = min (2 * p, n)
153159 P = Wold' * Wnew
154160 dW = Wnew - Wold * P
155- QQ, _ = qr! ([Wold dW])
161+ QQ, _ = LinearAlgebra . qr! ([Wold dW])
156162 Q = similar (Wold, n, r - p)
157163 @inbounds for j in Base. OneTo (r - p)
158164 for i in Base. OneTo (n)
@@ -161,23 +167,17 @@ function _stiefellog(Wold::StridedMatrix, Wnew::StridedMatrix;
161167 end
162168 Q = lmul! (QQ, Q)
163169 R = Q' * dW
164- Wext = [Wold Q]
165- F = qr! ([P; R])
166- U = lmul! (F. Q, TensorKit. MatrixAlgebra. one! (similar (P, r, r)))
170+ F = LinearAlgebra. qr! ([P; R])
171+ U = lmul! (F. Q, MatrixAlgebraKit. one! (similar (P, r, r)))
167172 U[1 : p, 1 : p] .= P
168173 U[(p + 1 ): r, 1 : p] .= R
169174 X = view (U, 1 : p, (p + 1 ): r)
170175 Y = view (U, (p + 1 ): r, (p + 1 ): r)
171176 if p < n
172- YSVD = svd! (Y)
173- mul! (X, X * (YSVD. V), (YSVD. U)' )
174- UsqrtS = YSVD. U
175- @inbounds for j in 1 : size (UsqrtS, 2 )
176- s = sqrt (YSVD. S[j])
177- @simd for i in 1 : size (UsqrtS, 1 )
178- UsqrtS[i, j] *= s
179- end
180- end
177+ USVᴴ = svd_compact! (Y)
178+ mul! (X, X * USVᴴ[3 ]' , USVᴴ[1 ]' )
179+ diagview (USVᴴ[2 ]) .= sqrt .(diagview (USVᴴ[2 ]))
180+ UsqrtS = rmul! (USVᴴ[1 ], USVᴴ[2 ])
181181 mul! (Y, UsqrtS, UsqrtS' )
182182 end
183183 logU = _projectantihermitian! (log (U))
0 commit comments