Skip to content

Commit 98ec13a

Browse files
committed
Merge branch 'matrixalgebra' of https://github.com/Jutho/TensorKit.jl into matrixalgebra
2 parents d80aab3 + 0fbc9ec commit 98ec13a

File tree

8 files changed

+257
-196
lines changed

8 files changed

+257
-196
lines changed

src/tensors/factorizations/adjoint.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,13 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
6666
return F
6767
end
6868
end
69+
# avoid amgiguity
70+
function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap,
71+
alg::TruncatedAlgorithm)
72+
return initialize_output(svd_compact!, t, alg.alg)
73+
end
74+
# to fix ambiguity
75+
function svd_trunc!(t::AdjointTensorMap, USVᴴ::Tuple{AdjointTensorMap,DiagonalTensorMap,AdjointTensorMap}, alg::TruncatedAlgorithm)
76+
USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg)
77+
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
78+
end

src/tensors/factorizations/factorizations.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ export tsvd, tsvd!, svdvals, svdvals!
88
export leftorth, leftorth!, rightorth, rightorth!
99
export leftnull, leftnull!, rightnull, rightnull!
1010
export copy_oftype, permutedcopy_oftype, one!
11-
export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace
11+
export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD
12+
#export LAPACK_HouseholderQR, LAPACK_HouseholderLQ
1213

1314
using ..TensorKit
1415
using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock, one!
@@ -21,7 +22,9 @@ using TensorOperations: Index2Tuple
2122
using MatrixAlgebraKit
2223
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy,
2324
NoTruncation, TruncationKeepAbove, TruncationKeepBelow,
24-
TruncationIntersection, TruncationKeepFiltered, DiagonalAlgorithm
25+
TruncationIntersection, TruncationKeepFiltered, PolarViaSVD,
26+
LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR,
27+
LAPACK_HouseholderLQ, DiagonalAlgorithm
2528
import MatrixAlgebraKit: default_algorithm,
2629
copy_input, check_input, initialize_output,
2730
qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!,

src/tensors/factorizations/implementations.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ _kindof(::Union{QR,QRpos}) = :qr
33
_kindof(::Union{LQ,LQpos}) = :lq
44
_kindof(::Polar) = :polar
55

6+
_kindof(::LAPACK_HouseholderQR) = :qr
7+
_kindof(::LAPACK_HouseholderLQ) = :lq
8+
_kindof(::LAPACK_SVDAlgorithm) = :svd
9+
_kindof(::PolarViaSVD) = :polar
10+
611
leftorth!(t; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...)
712

813
function _leftorth!(t::AbstractTensorMap, alg::Nothing, ; kwargs...)
@@ -19,14 +24,16 @@ function _leftorth!(t::AbstractTensorMap, alg::Union{QL,QLpos}; kwargs...)
1924
return Q, R
2025
end
2126
end
22-
function _leftorth!(t, alg::OFA; kwargs...)
27+
function _leftorth!(t, alg::Union{OFA,AbstractAlgorithm}; kwargs...)
2328
trunc = isempty(kwargs) ? nothing : (; kwargs...)
2429

25-
Base.depwarn(lazy"$alg is deprecated", :leftorth!)
30+
alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :leftorth!)
2631

2732
kind = _kindof(alg)
2833
if kind == :svd
29-
alg_svd = alg === SVD() ? LAPACK_QRIteration() :
34+
alg_svd = alg === LAPACK_QRIteration() ? alg :
35+
alg === LAPACK_DivideAndConquer() ? alg :
36+
alg === SVD() ? LAPACK_QRIteration() :
3037
alg === SDD() ? LAPACK_DivideAndConquer() :
3138
throw(ArgumentError(lazy"Unknown algorithm $alg"))
3239
return left_orth!(t; kind, alg_svd, trunc)
@@ -40,19 +47,22 @@ function _leftorth!(t, alg::OFA; kwargs...)
4047
end
4148
end
4249
# fallback to MatrixAlgebraKit version
43-
_leftorth!(t, alg; kwargs...) = left_orth!(t; alg, kwargs...)
50+
_leftorth!(t, alg; kwargs...) = left_orth!(t, alg; kwargs...)
4451

4552
function leftnull!(t::AbstractTensorMap;
46-
alg::Union{QR,QRpos,SVD,SDD,Nothing}=nothing, kwargs...)
53+
alg::Union{LAPACK_HouseholderQR,LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,QR,QRpos,SVD,SDD,Nothing}=nothing, kwargs...)
4754
InnerProductStyle(t) === EuclideanInnerProduct() ||
4855
throw_invalid_innerproduct(:leftnull!)
4956
trunc = isempty(kwargs) ? nothing : (; kwargs...)
57+
alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :leftnull!)
5058

5159
isnothing(alg) && return left_null!(t; trunc)
5260

5361
kind = _kindof(alg)
5462
if kind == :svd
55-
alg_svd = alg === SVD() ? LAPACK_QRIteration() :
63+
alg_svd = alg === LAPACK_QRIteration() ? alg :
64+
alg === LAPACK_DivideAndConquer() ? alg :
65+
alg === SVD() ? LAPACK_QRIteration() :
5666
alg === SDD() ? LAPACK_DivideAndConquer() :
5767
throw(ArgumentError(lazy"Unknown algorithm $alg"))
5868
return left_null!(t; kind, alg_svd, trunc)
@@ -65,10 +75,12 @@ function leftnull!(t::AbstractTensorMap;
6575
end
6676

6777
function rightorth!(t::AbstractTensorMap;
68-
alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...)
78+
alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...)
6979
InnerProductStyle(t) === EuclideanInnerProduct() ||
7080
throw_invalid_innerproduct(:rightorth!)
7181
trunc = isempty(kwargs) ? nothing : (; kwargs...)
82+
83+
alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightorth!)
7284

7385
isnothing(alg) && return right_orth!(t; trunc)
7486

@@ -82,7 +94,9 @@ function rightorth!(t::AbstractTensorMap;
8294

8395
kind = _kindof(alg)
8496
if kind == :svd
85-
alg_svd = alg === SVD() ? LAPACK_QRIteration() :
97+
alg_svd = alg === LAPACK_QRIteration() ? alg :
98+
alg === LAPACK_DivideAndConquer() ? alg :
99+
alg === SVD() ? LAPACK_QRIteration() :
86100
alg === SDD() ? LAPACK_DivideAndConquer() :
87101
throw(ArgumentError(lazy"Unknown algorithm $alg"))
88102
return right_orth!(t; kind, alg_svd, trunc)
@@ -97,16 +111,20 @@ function rightorth!(t::AbstractTensorMap;
97111
end
98112

99113
function rightnull!(t::AbstractTensorMap;
100-
alg::Union{LQ,LQpos,SVD,SDD,Nothing}=nothing, kwargs...)
114+
alg::Union{LAPACK_HouseholderLQ, LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,SVD,SDD,Nothing}=nothing, kwargs...)
101115
InnerProductStyle(t) === EuclideanInnerProduct() ||
102116
throw_invalid_innerproduct(:rightnull!)
103117
trunc = isempty(kwargs) ? nothing : (; kwargs...)
104118

119+
alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightnull!)
120+
105121
isnothing(alg) && return right_null!(t; trunc)
106122

107123
kind = _kindof(alg)
108124
if kind == :svd
109-
alg_svd = alg === SVD() ? LAPACK_QRIteration() :
125+
alg_svd = alg === LAPACK_QRIteration() ? alg :
126+
alg === LAPACK_DivideAndConquer() ? alg :
127+
alg === SVD() ? LAPACK_QRIteration() :
110128
alg === SDD() ? LAPACK_DivideAndConquer() :
111129
throw(ArgumentError(lazy"Unknown algorithm $alg"))
112130
return right_null!(t; kind, alg_svd, trunc)

src/tensors/factorizations/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ equivalent total dimension of the internal vector space is no larger than `χ`.
3030
The method `tsvd` also returns the truncation error `ϵ`, computed as the `p` norm of the
3131
singular values that were truncated.
3232
33-
THe keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK
33+
The keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK
3434
algorithm that computes the decomposition (`_gesvd` or `_gesdd`).
3535
3636
Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)`

src/tensors/factorizations/matrixalgebrakit.jl

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ end
379379
# -------------------
380380
const _T_WP = Tuple{<:AbstractTensorMap,<:AbstractTensorMap}
381381
const _T_PWᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap}
382-
using MatrixAlgebraKit: PolarViaSVD
383382

384383
function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP,
385384
::AbstractAlgorithm)
@@ -556,15 +555,8 @@ function initialize_output(::typeof(right_null!), t::AbstractTensorMap)
556555
return N
557556
end
558557

559-
for f! in (:left_null_svd!, :right_null_svd!)
560-
@eval function $f!(t::AbstractTensorMap, N, alg, ::Nothing=nothing)
561-
foreachblock(t, N) do _, (b, n)
562-
n′ = $f!(b, n, alg)
563-
# deal with the case where the output is not the same as the input
564-
n === n′ || copyto!(n, n′)
565-
return nothing
566-
end
567-
568-
return N
558+
for (f!, f_svd!) in zip((:left_null!, :right_null!), (:left_null_svd!, :right_null_svd!))
559+
@eval function $f_svd!(t::AbstractTensorMap, N, alg, ::Nothing=nothing)
560+
return $f!(t, N; alg_svd=alg)
569561
end
570562
end

test/factorizations.jl

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
spacelist = try
2+
if ENV["CI"] == "true"
3+
println("Detected running on CI")
4+
if Sys.iswindows()
5+
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂)
6+
elseif Sys.isapple()
7+
(Vtr, Vℤ₃, VfU₁, VfSU₂)
8+
else
9+
(Vtr, VU₁, VCU₁, VSU₂, VfSU₂)
10+
end
11+
else
12+
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)
13+
end
14+
catch
15+
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)
16+
end
17+
18+
for V in spacelist
19+
I = sectortype(first(V))
20+
Istr = TensorKit.type_repr(I)
21+
println("---------------------------------------")
22+
println("Tensors with symmetry: $Istr")
23+
println("---------------------------------------")
24+
@timedtestset "Tensors with symmetry: $Istr" verbose = true begin
25+
V1, V2, V3, V4, V5 = V
26+
@timedtestset "Factorization" begin
27+
W = V1 V2
28+
@testset for T in (Float32, ComplexF64)
29+
# Test both a normal tensor and an adjoint one.
30+
ts = (rand(T, W, W'), rand(T, W, W')')
31+
@testset for t in ts
32+
# test squares and rectangles here
33+
@testset "leftorth with $alg" for alg in
34+
(TensorKit.LAPACK_HouseholderQR(),
35+
TensorKit.LAPACK_HouseholderQR(positive=true),
36+
#TensorKit.QL(),
37+
#TensorKit.QLpos(),
38+
TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()),
39+
TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()),
40+
TensorKit.LAPACK_QRIteration(),
41+
TensorKit.LAPACK_DivideAndConquer())
42+
Q, R = @constinferred leftorth(t; alg=alg)
43+
@test isisometry(Q)
44+
tQR = Q * R
45+
@test tQR t
46+
end
47+
@testset "leftnull with $alg" for alg in
48+
(TensorKit.LAPACK_HouseholderQR(),
49+
TensorKit.LAPACK_QRIteration(),
50+
TensorKit.LAPACK_DivideAndConquer())
51+
N = @constinferred leftnull(t; alg=alg)
52+
@test isisometry(N)
53+
@test norm(N' * t) < 100 * eps(norm(t))
54+
end
55+
@testset "rightorth with $alg" for alg in
56+
(#TensorKit.RQ(), TensorKit.RQpos(),
57+
TensorKit.LAPACK_HouseholderLQ(),
58+
TensorKit.LAPACK_HouseholderLQ(positive=true),
59+
TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()),
60+
TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()),
61+
TensorKit.LAPACK_QRIteration(),
62+
TensorKit.LAPACK_DivideAndConquer())
63+
L, Q = @constinferred rightorth(t; alg=alg)
64+
@test isisometry(Q; side=:right)
65+
@test L * Q t
66+
end
67+
@testset "rightnull with $alg" for alg in
68+
(TensorKit.LAPACK_HouseholderLQ(),
69+
TensorKit.LAPACK_QRIteration(),
70+
TensorKit.LAPACK_DivideAndConquer())
71+
M = @constinferred rightnull(t; alg=alg)
72+
@test isisometry(M; side=:right)
73+
@test norm(t * M') < 100 * eps(norm(t))
74+
end
75+
@testset "tsvd with $alg" for alg in (TensorKit.LAPACK_QRIteration(),
76+
TensorKit.LAPACK_DivideAndConquer())
77+
U, S, V = @constinferred tsvd(t; alg=alg)
78+
@test isisometry(U)
79+
@test isisometry(V; side=:right)
80+
@test U * S * V t
81+
82+
s = LinearAlgebra.svdvals(t)
83+
s′ = LinearAlgebra.diag(S)
84+
for (c, b) in s
85+
@test b s′[c]
86+
end
87+
s = LinearAlgebra.svdvals(t')
88+
s′ = LinearAlgebra.diag(S')
89+
for (c, b) in s
90+
@test b s′[c]
91+
end
92+
end
93+
@testset "cond and rank" begin
94+
d1 = dim(codomain(t))
95+
d2 = dim(domain(t))
96+
@test rank(t) == min(d1, d2)
97+
M = leftnull(t)
98+
@test rank(M) == max(d1, d2) - min(d1, d2)
99+
t3 = unitary(T, V1 V2, V1 V2)
100+
@test cond(t3) one(real(T))
101+
@test rank(t3) == dim(V1 V2)
102+
t4 = randn(T, V1 V2, V1 V2)
103+
t4 = (t4 + t4') / 2
104+
vals = LinearAlgebra.eigvals(t4)
105+
λmax = maximum(s -> maximum(abs, s), values(vals))
106+
λmin = minimum(s -> minimum(abs, s), values(vals))
107+
@test cond(t4) λmax / λmin
108+
vals = LinearAlgebra.eigvals(t4')
109+
λmax = maximum(s -> maximum(abs, s), values(vals))
110+
λmin = minimum(s -> minimum(abs, s), values(vals))
111+
@test cond(t4') λmax / λmin
112+
end
113+
end
114+
@testset "empty tensor" begin
115+
t = randn(T, V1 V2, zero(V1))
116+
@testset "leftorth with $alg" for alg in
117+
(TensorKit.LAPACK_HouseholderQR(),
118+
TensorKit.LAPACK_HouseholderQR(positive=true),
119+
#TensorKit.QL(), TensorKit.QLpos(),
120+
TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()),
121+
TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()),
122+
TensorKit.LAPACK_QRIteration(),
123+
TensorKit.LAPACK_DivideAndConquer())
124+
Q, R = @constinferred leftorth(t; alg=alg)
125+
@test Q == t
126+
@test dim(Q) == dim(R) == 0
127+
end
128+
@testset "leftnull with $alg" for alg in
129+
(TensorKit.LAPACK_HouseholderQR(),
130+
TensorKit.LAPACK_QRIteration(),
131+
TensorKit.LAPACK_DivideAndConquer())
132+
N = @constinferred leftnull(t; alg=alg)
133+
@test isunitary(N)
134+
end
135+
@testset "rightorth with $alg" for alg in
136+
(#TensorKit.RQ(), TensorKit.RQpos(),
137+
TensorKit.LAPACK_HouseholderLQ(),
138+
TensorKit.LAPACK_HouseholderLQ(positive=true),
139+
TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()),
140+
TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()),
141+
TensorKit.LAPACK_QRIteration(),
142+
TensorKit.LAPACK_DivideAndConquer())
143+
L, Q = @constinferred rightorth(copy(t'); alg=alg)
144+
@test Q == t'
145+
@test dim(Q) == dim(L) == 0
146+
end
147+
@testset "rightnull with $alg" for alg in
148+
(TensorKit.LAPACK_HouseholderLQ(),
149+
TensorKit.LAPACK_QRIteration(),
150+
TensorKit.LAPACK_DivideAndConquer())
151+
M = @constinferred rightnull(copy(t'); alg=alg)
152+
@test isunitary(M)
153+
end
154+
@testset "tsvd with $alg" for alg in (TensorKit.LAPACK_QRIteration(),
155+
TensorKit.LAPACK_DivideAndConquer())
156+
U, S, V = @constinferred tsvd(t; alg=alg)
157+
@test U == t
158+
@test dim(U) == dim(S) == dim(V)
159+
end
160+
@testset "cond and rank" begin
161+
@test rank(t) == 0
162+
W2 = zero(V1) * zero(V2)
163+
t2 = rand(W2, W2)
164+
@test rank(t2) == 0
165+
@test cond(t2) == 0.0
166+
end
167+
end
168+
@testset "eig and isposdef" begin
169+
t = rand(T, V1, V1)
170+
D, V = eigen(t)
171+
@test t * V V * D
172+
173+
d = LinearAlgebra.eigvals(t; sortby=nothing)
174+
d′ = LinearAlgebra.diag(D)
175+
for (c, b) in d
176+
@test b d′[c]
177+
end
178+
179+
# Somehow moving these test before the previous one gives rise to errors
180+
# with T=Float32 on x86 platforms. Is this an OpenBLAS issue?
181+
VdV = V' * V
182+
VdV = (VdV + VdV') / 2
183+
@test isposdef(VdV)
184+
185+
@test !isposdef(t) # unlikely for non-hermitian map
186+
t2 = (t + t')
187+
D, V = eigen(t2)
188+
@test isisometry(V)
189+
D̃, Ṽ = @constinferred eigh(t2)
190+
@test D
191+
@test V
192+
λ = minimum(minimum(real(LinearAlgebra.diag(b)))
193+
for (c, b) in blocks(D))
194+
@test cond(Ṽ) one(real(T))
195+
@test isposdef(t2) == isposdef(λ)
196+
@test isposdef(t2 - λ * one(t2) + 0.1 * one(t2))
197+
@test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2))
198+
end
199+
end
200+
end
201+
end
202+
end

0 commit comments

Comments
 (0)