@@ -46,341 +46,5 @@ Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg
4646const OFA = OrthogonalFactorizationAlgorithm
4747const SVDAlg = Union{SVD,SDD}
4848
49- # Matrix algebra: entrypoint for calling matrix methods from within tensor implementations
50- # ------------------------------------------------------------------------------------------
51- module MatrixAlgebra
52- # TODO : all methods tha twe define here will need an extended version for CuMatrix in the
53- # CUDA package extension.
54-
55- # TODO : other methods to include here:
56- # mul! (possibly call matmul! instead)
57- # adjoint!
58- # sylvester
59- # exp!
60- # schur!?
61- #
62-
63- using LinearAlgebra
64- using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare
65-
66- using .. TensorKit: OrthogonalFactorizationAlgorithm,
67- QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar
68-
69- # only defined in >v1.7
70- @static if VERSION < v " 1.7-"
71- _rf_findmax ((fm, im), (fx, ix)) = isless (fm, fx) ? (fx, ix) : (fm, im)
72- _argmax (f, domain) = mapfoldl (x -> (f (x), x), _rf_findmax, domain)[2 ]
73- else
74- _argmax (f, domain) = argmax (f, domain)
75- end
76-
77- # TODO : define for CuMatrix if we support this
78- function one! (A:: StridedMatrix )
79- length (A) > 0 || return A
80- copyto! (A, LinearAlgebra. I)
81- return A
82- end
83-
8449safesign (s:: Real ) = ifelse (s < zero (s), - one (s), + one (s))
8550safesign (s:: Complex ) = ifelse (iszero (s), one (s), s / abs (s))
86-
87- function leftorth! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{QR,QRpos} , atol:: Real )
88- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
89- m, n = size (A)
90- k = min (m, n)
91- A, T = LAPACK. geqrt! (A, min (minimum (size (A)), 36 ))
92- Q = similar (A, m, k)
93- for j in 1 : k
94- for i in 1 : m
95- Q[i, j] = i == j
96- end
97- end
98- Q = LAPACK. gemqrt! (' L' , ' N' , A, T, Q)
99- R = triu! (A[1 : k, :])
100-
101- if isa (alg, QRpos)
102- @inbounds for j in 1 : k
103- s = safesign (R[j, j])
104- @simd for i in 1 : m
105- Q[i, j] *= s
106- end
107- end
108- @inbounds for j in size (R, 2 ): - 1 : 1
109- for i in 1 : min (k, j)
110- R[i, j] = R[i, j] * conj (safesign (R[i, i]))
111- end
112- end
113- end
114- return Q, R
115- end
116-
117- function leftorth! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{QL,QLpos} , atol:: Real )
118- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
119- m, n = size (A)
120- @assert m >= n
121-
122- nhalf = div (n, 2 )
123- # swap columns in A
124- @inbounds for j in 1 : nhalf, i in 1 : m
125- A[i, j], A[i, n + 1 - j] = A[i, n + 1 - j], A[i, j]
126- end
127- Q, R = leftorth! (A, isa (alg, QL) ? QR () : QRpos (), atol)
128-
129- # swap columns in Q
130- @inbounds for j in 1 : nhalf, i in 1 : m
131- Q[i, j], Q[i, n + 1 - j] = Q[i, n + 1 - j], Q[i, j]
132- end
133- # swap rows and columns in R
134- @inbounds for j in 1 : nhalf, i in 1 : n
135- R[i, j], R[n + 1 - i, n + 1 - j] = R[n + 1 - i, n + 1 - j], R[i, j]
136- end
137- if isodd (n)
138- j = nhalf + 1
139- @inbounds for i in 1 : nhalf
140- R[i, j], R[n + 1 - i, j] = R[n + 1 - i, j], R[i, j]
141- end
142- end
143- return Q, R
144- end
145-
146- function leftorth! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{SVD,SDD,Polar} , atol:: Real )
147- U, S, V = alg isa SVD ? LAPACK. gesvd! (' S' , ' S' , A) : LAPACK. gesdd! (' S' , A)
148- if isa (alg, Union{SVD,SDD})
149- n = count (s -> s .> atol, S)
150- if n != length (S)
151- return U[:, 1 : n], lmul! (Diagonal (S[1 : n]), V[1 : n, :])
152- else
153- return U, lmul! (Diagonal (S), V)
154- end
155- else
156- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
157- # TODO : check Lapack to see if we can recycle memory of A
158- Q = mul! (A, U, V)
159- Sq = map! (sqrt, S, S)
160- SqV = lmul! (Diagonal (Sq), V)
161- R = SqV' * SqV
162- return Q, R
163- end
164- end
165-
166- function leftnull! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{QR,QRpos} , atol:: Real )
167- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
168- m, n = size (A)
169- m >= n || throw (ArgumentError (" no null space if less rows than columns" ))
170-
171- A, T = LAPACK. geqrt! (A, min (minimum (size (A)), 36 ))
172- N = similar (A, m, max (0 , m - n))
173- fill! (N, 0 )
174- for k in 1 : (m - n)
175- N[n + k, k] = 1
176- end
177- return N = LAPACK. gemqrt! (' L' , ' N' , A, T, N)
178- end
179-
180- function leftnull! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{SVD,SDD} , atol:: Real )
181- size (A, 2 ) == 0 && return one! (similar (A, (size (A, 1 ), size (A, 1 ))))
182- U, S, V = alg isa SVD ? LAPACK. gesvd! (' A' , ' N' , A) : LAPACK. gesdd! (' A' , A)
183- indstart = count (> (atol), S) + 1
184- return U[:, indstart: end ]
185- end
186-
187- function rightorth! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{LQ,LQpos,RQ,RQpos} ,
188- atol:: Real )
189- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
190- # TODO : geqrfp seems a bit slower than geqrt in the intermediate region around
191- # matrix size 100, which is the interesting region. => Investigate and fix
192- m, n = size (A)
193- k = min (m, n)
194- At = transpose! (similar (A, n, m), A)
195-
196- if isa (alg, RQ) || isa (alg, RQpos)
197- @assert m <= n
198-
199- mhalf = div (m, 2 )
200- # swap columns in At
201- @inbounds for j in 1 : mhalf, i in 1 : n
202- At[i, j], At[i, m + 1 - j] = At[i, m + 1 - j], At[i, j]
203- end
204- Qt, Rt = leftorth! (At, isa (alg, RQ) ? QR () : QRpos (), atol)
205-
206- @inbounds for j in 1 : mhalf, i in 1 : n
207- Qt[i, j], Qt[i, m + 1 - j] = Qt[i, m + 1 - j], Qt[i, j]
208- end
209- @inbounds for j in 1 : mhalf, i in 1 : m
210- Rt[i, j], Rt[m + 1 - i, m + 1 - j] = Rt[m + 1 - i, m + 1 - j], Rt[i, j]
211- end
212- if isodd (m)
213- j = mhalf + 1
214- @inbounds for i in 1 : mhalf
215- Rt[i, j], Rt[m + 1 - i, j] = Rt[m + 1 - i, j], Rt[i, j]
216- end
217- end
218- Q = transpose! (A, Qt)
219- R = transpose! (similar (A, (m, m)), Rt) # TODO : efficient in place
220- return R, Q
221- else
222- Qt, Lt = leftorth! (At, alg' , atol)
223- if m > n
224- L = transpose! (A, Lt)
225- Q = transpose! (similar (A, (n, n)), Qt) # TODO : efficient in place
226- else
227- Q = transpose! (A, Qt)
228- L = transpose! (similar (A, (m, m)), Lt) # TODO : efficient in place
229- end
230- return L, Q
231- end
232- end
233-
234- function rightorth! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{SVD,SDD,Polar} , atol:: Real )
235- U, S, V = alg isa SVD ? LAPACK. gesvd! (' S' , ' S' , A) : LAPACK. gesdd! (' S' , A)
236- if isa (alg, Union{SVD,SDD})
237- n = count (s -> s .> atol, S)
238- if n != length (S)
239- return rmul! (U[:, 1 : n], Diagonal (S[1 : n])), V[1 : n, :]
240- else
241- return rmul! (U, Diagonal (S)), V
242- end
243- else
244- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
245- Q = mul! (A, U, V)
246- Sq = map! (sqrt, S, S)
247- USq = rmul! (U, Diagonal (Sq))
248- L = USq * USq'
249- return L, Q
250- end
251- end
252-
253- function rightnull! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{LQ,LQpos} , atol:: Real )
254- iszero (atol) || throw (ArgumentError (" nonzero atol not supported by $alg " ))
255- m, n = size (A)
256- k = min (m, n)
257- At = adjoint! (similar (A, n, m), A)
258- At, T = LAPACK. geqrt! (At, min (k, 36 ))
259- N = similar (A, max (n - m, 0 ), n)
260- fill! (N, 0 )
261- for k in 1 : (n - m)
262- N[k, m + k] = 1
263- end
264- return N = LAPACK. gemqrt! (' R' , eltype (At) <: Real ? ' T' : ' C' , At, T, N)
265- end
266-
267- function rightnull! (A:: StridedMatrix{<:BlasFloat} , alg:: Union{SVD,SDD} , atol:: Real )
268- size (A, 1 ) == 0 && return one! (similar (A, (size (A, 2 ), size (A, 2 ))))
269- U, S, V = alg isa SVD ? LAPACK. gesvd! (' N' , ' A' , A) : LAPACK. gesdd! (' A' , A)
270- indstart = count (> (atol), S) + 1
271- return V[indstart: end , :]
272- end
273-
274- function svd! (A:: StridedMatrix{T} , alg:: Union{SVD,SDD} ) where {T<: BlasFloat }
275- # fix another type instability in LAPACK wrappers
276- TT = Tuple{Matrix{T},Vector{real (T)},Matrix{T}}
277- U, S, V = alg isa SVD ? LAPACK. gesvd! (' S' , ' S' , A):: TT : LAPACK. gesdd! (' S' , A):: TT
278- return U, S, V
279- end
280-
281- function eig! (A:: StridedMatrix{T} ; permute:: Bool = true , scale:: Bool = true ) where {T<: BlasReal }
282- n = checksquare (A)
283- n == 0 && return zeros (Complex{T}, 0 ), zeros (Complex{T}, 0 , 0 )
284-
285- A, DR, DI, VL, VR, _ = LAPACK. geevx! (permute ? (scale ? ' B' : ' P' ) :
286- (scale ? ' S' : ' N' ), ' N' , ' V' , ' N' , A)
287- D = complex .(DR, DI)
288- V = zeros (Complex{T}, n, n)
289- j = 1
290- while j <= n
291- if DI[j] == 0
292- vr = view (VR, :, j)
293- s = conj (sign (_argmax (abs, vr)))
294- V[:, j] .= s .* vr
295- else
296- vr = view (VR, :, j)
297- vi = view (VR, :, j + 1 )
298- s = conj (sign (_argmax (abs, vr))) # vectors coming from lapack have already real absmax component
299- V[:, j] .= s .* (vr .+ im .* vi)
300- V[:, j + 1 ] .= s .* (vr .- im .* vi)
301- j += 1
302- end
303- j += 1
304- end
305- return D, V
306- end
307-
308- function eig! (A:: StridedMatrix{T} ; permute:: Bool = true ,
309- scale:: Bool = true ) where {T<: BlasComplex }
310- n = checksquare (A)
311- n == 0 && return zeros (T, 0 ), zeros (T, 0 , 0 )
312- D, V = LAPACK. geevx! (permute ? (scale ? ' B' : ' P' ) : (scale ? ' S' : ' N' ), ' N' , ' V' , ' N' ,
313- A)[[2 , 4 ]]
314- for j in 1 : n
315- v = view (V, :, j)
316- s = conj (sign (_argmax (abs, v)))
317- v .*= s
318- end
319- return D, V
320- end
321-
322- function eigh! (A:: StridedMatrix{T} ) where {T<: BlasFloat }
323- n = checksquare (A)
324- n == 0 && return zeros (real (T), 0 ), zeros (T, 0 , 0 )
325- D, V = LAPACK. syevr! (' V' , ' A' , ' U' , A, 0.0 , 0.0 , 0 , 0 , - 1.0 )
326- for j in 1 : n
327- v = view (V, :, j)
328- s = conj (sign (_argmax (abs, v)))
329- v .*= s
330- end
331- return D, V
332- end
333-
334- # # Old stuff and experiments
335-
336- # using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException,
337- # DimensionMismatch, SingularException, PosDefException, chkstride1,
338- # checksquare,
339- # triu!
340-
341- # TODO : reconsider the following implementation
342- # Unfortunately, geqrfp seems a bit slower than geqrt in the intermediate region
343- # around matrix size 100, which is the interesting region. => Investigate and maybe fix
344- # function _leftorth!(A::StridedMatrix{<:BlasFloat})
345- # m, n = size(A)
346- # A, τ = geqrfp!(A)
347- # Q = LAPACK.ormqr!('L', 'N', A, τ, eye(eltype(A), m, min(m, n)))
348- # R = triu!(A[1:min(m, n), :])
349- # return Q, R
350- # end
351-
352- # geqrfp!: computes qrpos factorization, missing in Base
353- # geqrfp!(A::StridedMatrix{<:BlasFloat}) =
354- # ((m, n) = size(A); geqrfp!(A, similar(A, min(m, n))))
355- #
356- # for (geqrfp, elty, relty) in
357- # ((:dgeqrfp_, :Float64, :Float64), (:sgeqrfp_, :Float32, :Float32),
358- # (:zgeqrfp_, :ComplexF64, :Float64), (:cgeqrfp_, :ComplexF32, :Float32))
359- # @eval begin
360- # function geqrfp!(A::StridedMatrix{$elty}, tau::StridedVector{$elty})
361- # chkstride1(A, tau)
362- # m, n = size(A)
363- # if length(tau) != min(m, n)
364- # throw(DimensionMismatch("tau has length $(length(tau)), but needs length $(min(m, n))"))
365- # end
366- # work = Vector{$elty}(1)
367- # lwork = BlasInt(-1)
368- # info = Ref{BlasInt}()
369- # for i = 1:2 # first call returns lwork as work[1]
370- # ccall((@blasfunc($geqrfp), liblapack), Nothing,
371- # (Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
372- # Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}),
373- # Ref(m), Ref(n), A, Ref(max(1, stride(A, 2))),
374- # tau, work, Ref(lwork), info)
375- # chklapackerror(info[])
376- # if i == 1
377- # lwork = BlasInt(real(work[1]))
378- # resize!(work, lwork)
379- # end
380- # end
381- # A, tau
382- # end
383- # end
384- # end
385-
386- end
0 commit comments