Skip to content

Commit 113f9a1

Browse files
authored
Merge pull request #272 from kshyatt/ksh/surgery
2 parents 8549b2a + 59bdfd8 commit 113f9a1

File tree

7 files changed

+17
-351
lines changed

7 files changed

+17
-351
lines changed

src/auxiliary/linalg.jl

Lines changed: 0 additions & 336 deletions
Original file line numberDiff line numberDiff line change
@@ -46,341 +46,5 @@ Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg
4646
const OFA = OrthogonalFactorizationAlgorithm
4747
const SVDAlg = Union{SVD,SDD}
4848

49-
# Matrix algebra: entrypoint for calling matrix methods from within tensor implementations
50-
#------------------------------------------------------------------------------------------
51-
module MatrixAlgebra
52-
# TODO: all methods tha twe define here will need an extended version for CuMatrix in the
53-
# CUDA package extension.
54-
55-
# TODO: other methods to include here:
56-
# mul! (possibly call matmul! instead)
57-
# adjoint!
58-
# sylvester
59-
# exp!
60-
# schur!?
61-
#
62-
63-
using LinearAlgebra
64-
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare
65-
66-
using ..TensorKit: OrthogonalFactorizationAlgorithm,
67-
QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar
68-
69-
# only defined in >v1.7
70-
@static if VERSION < v"1.7-"
71-
_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im)
72-
_argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2]
73-
else
74-
_argmax(f, domain) = argmax(f, domain)
75-
end
76-
77-
# TODO: define for CuMatrix if we support this
78-
function one!(A::StridedMatrix)
79-
length(A) > 0 || return A
80-
copyto!(A, LinearAlgebra.I)
81-
return A
82-
end
83-
8449
safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s))
8550
safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
86-
87-
function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real)
88-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
89-
m, n = size(A)
90-
k = min(m, n)
91-
A, T = LAPACK.geqrt!(A, min(minimum(size(A)), 36))
92-
Q = similar(A, m, k)
93-
for j in 1:k
94-
for i in 1:m
95-
Q[i, j] = i == j
96-
end
97-
end
98-
Q = LAPACK.gemqrt!('L', 'N', A, T, Q)
99-
R = triu!(A[1:k, :])
100-
101-
if isa(alg, QRpos)
102-
@inbounds for j in 1:k
103-
s = safesign(R[j, j])
104-
@simd for i in 1:m
105-
Q[i, j] *= s
106-
end
107-
end
108-
@inbounds for j in size(R, 2):-1:1
109-
for i in 1:min(k, j)
110-
R[i, j] = R[i, j] * conj(safesign(R[i, i]))
111-
end
112-
end
113-
end
114-
return Q, R
115-
end
116-
117-
function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QL,QLpos}, atol::Real)
118-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
119-
m, n = size(A)
120-
@assert m >= n
121-
122-
nhalf = div(n, 2)
123-
#swap columns in A
124-
@inbounds for j in 1:nhalf, i in 1:m
125-
A[i, j], A[i, n + 1 - j] = A[i, n + 1 - j], A[i, j]
126-
end
127-
Q, R = leftorth!(A, isa(alg, QL) ? QR() : QRpos(), atol)
128-
129-
#swap columns in Q
130-
@inbounds for j in 1:nhalf, i in 1:m
131-
Q[i, j], Q[i, n + 1 - j] = Q[i, n + 1 - j], Q[i, j]
132-
end
133-
#swap rows and columns in R
134-
@inbounds for j in 1:nhalf, i in 1:n
135-
R[i, j], R[n + 1 - i, n + 1 - j] = R[n + 1 - i, n + 1 - j], R[i, j]
136-
end
137-
if isodd(n)
138-
j = nhalf + 1
139-
@inbounds for i in 1:nhalf
140-
R[i, j], R[n + 1 - i, j] = R[n + 1 - i, j], R[i, j]
141-
end
142-
end
143-
return Q, R
144-
end
145-
146-
function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real)
147-
U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A)
148-
if isa(alg, Union{SVD,SDD})
149-
n = count(s -> s .> atol, S)
150-
if n != length(S)
151-
return U[:, 1:n], lmul!(Diagonal(S[1:n]), V[1:n, :])
152-
else
153-
return U, lmul!(Diagonal(S), V)
154-
end
155-
else
156-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
157-
# TODO: check Lapack to see if we can recycle memory of A
158-
Q = mul!(A, U, V)
159-
Sq = map!(sqrt, S, S)
160-
SqV = lmul!(Diagonal(Sq), V)
161-
R = SqV' * SqV
162-
return Q, R
163-
end
164-
end
165-
166-
function leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real)
167-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
168-
m, n = size(A)
169-
m >= n || throw(ArgumentError("no null space if less rows than columns"))
170-
171-
A, T = LAPACK.geqrt!(A, min(minimum(size(A)), 36))
172-
N = similar(A, m, max(0, m - n))
173-
fill!(N, 0)
174-
for k in 1:(m - n)
175-
N[n + k, k] = 1
176-
end
177-
return N = LAPACK.gemqrt!('L', 'N', A, T, N)
178-
end
179-
180-
function leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real)
181-
size(A, 2) == 0 && return one!(similar(A, (size(A, 1), size(A, 1))))
182-
U, S, V = alg isa SVD ? LAPACK.gesvd!('A', 'N', A) : LAPACK.gesdd!('A', A)
183-
indstart = count(>(atol), S) + 1
184-
return U[:, indstart:end]
185-
end
186-
187-
function rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos},
188-
atol::Real)
189-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
190-
# TODO: geqrfp seems a bit slower than geqrt in the intermediate region around
191-
# matrix size 100, which is the interesting region. => Investigate and fix
192-
m, n = size(A)
193-
k = min(m, n)
194-
At = transpose!(similar(A, n, m), A)
195-
196-
if isa(alg, RQ) || isa(alg, RQpos)
197-
@assert m <= n
198-
199-
mhalf = div(m, 2)
200-
# swap columns in At
201-
@inbounds for j in 1:mhalf, i in 1:n
202-
At[i, j], At[i, m + 1 - j] = At[i, m + 1 - j], At[i, j]
203-
end
204-
Qt, Rt = leftorth!(At, isa(alg, RQ) ? QR() : QRpos(), atol)
205-
206-
@inbounds for j in 1:mhalf, i in 1:n
207-
Qt[i, j], Qt[i, m + 1 - j] = Qt[i, m + 1 - j], Qt[i, j]
208-
end
209-
@inbounds for j in 1:mhalf, i in 1:m
210-
Rt[i, j], Rt[m + 1 - i, m + 1 - j] = Rt[m + 1 - i, m + 1 - j], Rt[i, j]
211-
end
212-
if isodd(m)
213-
j = mhalf + 1
214-
@inbounds for i in 1:mhalf
215-
Rt[i, j], Rt[m + 1 - i, j] = Rt[m + 1 - i, j], Rt[i, j]
216-
end
217-
end
218-
Q = transpose!(A, Qt)
219-
R = transpose!(similar(A, (m, m)), Rt) # TODO: efficient in place
220-
return R, Q
221-
else
222-
Qt, Lt = leftorth!(At, alg', atol)
223-
if m > n
224-
L = transpose!(A, Lt)
225-
Q = transpose!(similar(A, (n, n)), Qt) # TODO: efficient in place
226-
else
227-
Q = transpose!(A, Qt)
228-
L = transpose!(similar(A, (m, m)), Lt) # TODO: efficient in place
229-
end
230-
return L, Q
231-
end
232-
end
233-
234-
function rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real)
235-
U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A)
236-
if isa(alg, Union{SVD,SDD})
237-
n = count(s -> s .> atol, S)
238-
if n != length(S)
239-
return rmul!(U[:, 1:n], Diagonal(S[1:n])), V[1:n, :]
240-
else
241-
return rmul!(U, Diagonal(S)), V
242-
end
243-
else
244-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
245-
Q = mul!(A, U, V)
246-
Sq = map!(sqrt, S, S)
247-
USq = rmul!(U, Diagonal(Sq))
248-
L = USq * USq'
249-
return L, Q
250-
end
251-
end
252-
253-
function rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos}, atol::Real)
254-
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
255-
m, n = size(A)
256-
k = min(m, n)
257-
At = adjoint!(similar(A, n, m), A)
258-
At, T = LAPACK.geqrt!(At, min(k, 36))
259-
N = similar(A, max(n - m, 0), n)
260-
fill!(N, 0)
261-
for k in 1:(n - m)
262-
N[k, m + k] = 1
263-
end
264-
return N = LAPACK.gemqrt!('R', eltype(At) <: Real ? 'T' : 'C', At, T, N)
265-
end
266-
267-
function rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real)
268-
size(A, 1) == 0 && return one!(similar(A, (size(A, 2), size(A, 2))))
269-
U, S, V = alg isa SVD ? LAPACK.gesvd!('N', 'A', A) : LAPACK.gesdd!('A', A)
270-
indstart = count(>(atol), S) + 1
271-
return V[indstart:end, :]
272-
end
273-
274-
function svd!(A::StridedMatrix{T}, alg::Union{SVD,SDD}) where {T<:BlasFloat}
275-
# fix another type instability in LAPACK wrappers
276-
TT = Tuple{Matrix{T},Vector{real(T)},Matrix{T}}
277-
U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A)::TT : LAPACK.gesdd!('S', A)::TT
278-
return U, S, V
279-
end
280-
281-
function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where {T<:BlasReal}
282-
n = checksquare(A)
283-
n == 0 && return zeros(Complex{T}, 0), zeros(Complex{T}, 0, 0)
284-
285-
A, DR, DI, VL, VR, _ = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') :
286-
(scale ? 'S' : 'N'), 'N', 'V', 'N', A)
287-
D = complex.(DR, DI)
288-
V = zeros(Complex{T}, n, n)
289-
j = 1
290-
while j <= n
291-
if DI[j] == 0
292-
vr = view(VR, :, j)
293-
s = conj(sign(_argmax(abs, vr)))
294-
V[:, j] .= s .* vr
295-
else
296-
vr = view(VR, :, j)
297-
vi = view(VR, :, j + 1)
298-
s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component
299-
V[:, j] .= s .* (vr .+ im .* vi)
300-
V[:, j + 1] .= s .* (vr .- im .* vi)
301-
j += 1
302-
end
303-
j += 1
304-
end
305-
return D, V
306-
end
307-
308-
function eig!(A::StridedMatrix{T}; permute::Bool=true,
309-
scale::Bool=true) where {T<:BlasComplex}
310-
n = checksquare(A)
311-
n == 0 && return zeros(T, 0), zeros(T, 0, 0)
312-
D, V = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : (scale ? 'S' : 'N'), 'N', 'V', 'N',
313-
A)[[2, 4]]
314-
for j in 1:n
315-
v = view(V, :, j)
316-
s = conj(sign(_argmax(abs, v)))
317-
v .*= s
318-
end
319-
return D, V
320-
end
321-
322-
function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat}
323-
n = checksquare(A)
324-
n == 0 && return zeros(real(T), 0), zeros(T, 0, 0)
325-
D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)
326-
for j in 1:n
327-
v = view(V, :, j)
328-
s = conj(sign(_argmax(abs, v)))
329-
v .*= s
330-
end
331-
return D, V
332-
end
333-
334-
## Old stuff and experiments
335-
336-
# using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException,
337-
# DimensionMismatch, SingularException, PosDefException, chkstride1,
338-
# checksquare,
339-
# triu!
340-
341-
# TODO: reconsider the following implementation
342-
# Unfortunately, geqrfp seems a bit slower than geqrt in the intermediate region
343-
# around matrix size 100, which is the interesting region. => Investigate and maybe fix
344-
# function _leftorth!(A::StridedMatrix{<:BlasFloat})
345-
# m, n = size(A)
346-
# A, τ = geqrfp!(A)
347-
# Q = LAPACK.ormqr!('L', 'N', A, τ, eye(eltype(A), m, min(m, n)))
348-
# R = triu!(A[1:min(m, n), :])
349-
# return Q, R
350-
# end
351-
352-
# geqrfp!: computes qrpos factorization, missing in Base
353-
# geqrfp!(A::StridedMatrix{<:BlasFloat}) =
354-
# ((m, n) = size(A); geqrfp!(A, similar(A, min(m, n))))
355-
#
356-
# for (geqrfp, elty, relty) in
357-
# ((:dgeqrfp_, :Float64, :Float64), (:sgeqrfp_, :Float32, :Float32),
358-
# (:zgeqrfp_, :ComplexF64, :Float64), (:cgeqrfp_, :ComplexF32, :Float32))
359-
# @eval begin
360-
# function geqrfp!(A::StridedMatrix{$elty}, tau::StridedVector{$elty})
361-
# chkstride1(A, tau)
362-
# m, n = size(A)
363-
# if length(tau) != min(m, n)
364-
# throw(DimensionMismatch("tau has length $(length(tau)), but needs length $(min(m, n))"))
365-
# end
366-
# work = Vector{$elty}(1)
367-
# lwork = BlasInt(-1)
368-
# info = Ref{BlasInt}()
369-
# for i = 1:2 # first call returns lwork as work[1]
370-
# ccall((@blasfunc($geqrfp), liblapack), Nothing,
371-
# (Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
372-
# Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}),
373-
# Ref(m), Ref(n), A, Ref(max(1, stride(A, 2))),
374-
# tau, work, Ref(lwork), info)
375-
# chklapackerror(info[])
376-
# if i == 1
377-
# lwork = BlasInt(real(work[1]))
378-
# resize!(work, lwork)
379-
# end
380-
# end
381-
# A, tau
382-
# end
383-
# end
384-
# end
385-
386-
end

src/auxiliary/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ function randisometry!(rng::Random.AbstractRNG, A::AbstractMatrix)
2020
dims = size(A)
2121
dims[1] >= dims[2] ||
2222
throw(DimensionMismatch("cannot create isometric matrix with dimensions $dims; isometry needs to be tall or square"))
23-
Q, = MatrixAlgebra.leftorth!(Random.randn!(rng, A), QRpos(), 0)
23+
Q, = leftorth!(Random.randn!(rng, A); alg=QRpos())
2424
return copy!(A, Q)
2525
end

src/tensors/factorizations/factorizations.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ export eig, eig!, eigh, eigh!
77
export tsvd, tsvd!, svdvals, svdvals!
88
export leftorth, leftorth!, rightorth, rightorth!
99
export leftnull, leftnull!, rightnull, rightnull!
10-
export copy_oftype, permutedcopy_oftype
10+
export copy_oftype, permutedcopy_oftype, one!
1111
export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace
1212

1313
using ..TensorKit
14-
using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock
15-
using ..MatrixAlgebra: MatrixAlgebra
14+
using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock, one!
1615

1716
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!
1817
import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian
@@ -42,6 +41,8 @@ include("matrixalgebrakit.jl")
4241
include("truncation.jl")
4342
include("deprecations.jl")
4443

44+
TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A)
45+
4546
function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
4647
t = permute(t, (p₁, p₂); copy=false)
4748
return isisometry(t)
@@ -112,7 +113,7 @@ function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD})
112113
V = zerovector!(similar(b.diag, lb, lb))
113114
p = sortperm(b.diag; by=abs, rev=true)
114115
for (i, pi) in enumerate(p)
115-
U[pi, i] = MatrixAlgebra.safesign(b.diag[pi])
116+
U[pi, i] = safesign(b.diag[pi])
116117
V[i, pi] = 1
117118
end
118119
Σ = abs.(view(b.diag, p))

src/tensors/factorizations/implementations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ _kindof(::Union{QR,QRpos}) = :qr
33
_kindof(::Union{LQ,LQpos}) = :lq
44
_kindof(::Polar) = :polar
55

6-
leftorth!(t::AbstractTensorMap; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...)
6+
leftorth!(t; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...)
77

88
function _leftorth!(t::AbstractTensorMap, alg::Nothing, ; kwargs...)
99
return isempty(kwargs) ? left_orth!(t) : left_orth!(t; trunc=(; kwargs...))

0 commit comments

Comments
 (0)