Skip to content

Commit 3d65453

Browse files
lkdvoskshyattJutho
authored
Updates for MatrixAlgebraKit v0.6 (#312)
* bump MatrixAlgebraKit compat to 0.6 * update `ishermitian` implementations * update `isisometric` implementations * add projections * Incremental fixes * orthnull progress * Adjoint factorizations (AD not working) * add missing MAK. * remove `eig(::Adjoint)` This doesn't actually hold for non-hermitian matrices * add truncation error implementation * add nullspace truncation specializations * update orth interface in tests * fix wrong variable * improve adjoint support * remove unnecessary left/rightorth functions * pass kwargs in isisometri * add tests for truncation error * more nullspace tests and fixes * remove unnecessary specialization * add missing argument * remove unnecessary specializations * slight reorganization * export new functionality * simplify truncationerror implementation * fix testcase * refrain from defining left_null functions to avoid ambiguities * fix ad tests * add projection tests * temporary AD fixes * try and add back some AD tests * apply code suggestions * stabilize AD test * some changes * more updates: formatting and deprecations * some more review updates * handle fully-truncated blocks * fix diagonal adjoint implementation * fix property name for PolarViaSVD * copy_input uses `f` instead of `f!` * remove more implementations to fix code * remove `check_input` --------- Co-authored-by: Katharine Hyatt <[email protected]> Co-authored-by: Katharine Hyatt <[email protected]> Co-authored-by: Jutho Haegeman <[email protected]>
1 parent a5811da commit 3d65453

File tree

14 files changed

+336
-782
lines changed

14 files changed

+336
-782
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Combinatorics = "1"
3535
FiniteDifferences = "0.12"
3636
LRUCache = "1.0.2"
3737
LinearAlgebra = "1"
38-
MatrixAlgebraKit = "0.5.0"
38+
MatrixAlgebraKit = "0.6.0"
3939
OhMyThreads = "0.8.0"
4040
PackageExtensionCompat = "1"
4141
Printf = "1"

src/TensorKit.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ export left_orth, right_orth, left_null, right_null,
7979
qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!,
8080
svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc,
8181
exp, exp!,
82-
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eig_full!, eig_full, eig_trunc!,
83-
eig_trunc,
84-
eigh_vals!, eigh_vals, eig_vals!, eig_vals,
85-
isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond
82+
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals,
83+
eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals,
84+
ishermitian, project_hermitian, project_hermitian!,
85+
isantihermitian, project_antihermitian, project_antihermitian!,
86+
isisometric, isunitary, project_isometric, project_isometric!,
87+
isposdef, isposdef!, sylvester, rank, cond
8688

8789
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
8890
repartition!
@@ -135,7 +137,7 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
135137
adjoint, adjoint!, transpose, transpose!,
136138
lu, pinv, sylvester,
137139
eigen, eigen!, svd, svd!,
138-
isposdef, isposdef!, ishermitian, rank, cond,
140+
isposdef, isposdef!, rank, cond,
139141
Diagonal, Hermitian
140142
using MatrixAlgebraKit
141143

src/auxiliary/deprecate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function tsvd(t::AbstractTensorMap; kwargs...)
186186
Base.depwarn("p is a deprecated kwarg, and should be specified through the truncation strategy", :tsvd)
187187
kwargs = _drop_p(; kwargs...)
188188
end
189-
return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...)
189+
return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : (svd_compact(t; kwargs...)..., abs(zero(scalartype(t))))
190190
end
191191
function tsvd!(t::AbstractTensorMap; kwargs...)
192192
Base.depwarn("`tsvd!` is deprecated, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", :tsvd!)

src/factorizations/adjoint.jl

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,54 @@ _adjoint(alg::MAK.LAPACK_HouseholderQR) = MAK.LAPACK_HouseholderLQ(; alg.kwargs.
66
_adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs...)
77
_adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...)
88
_adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...)
9-
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svdalg))
9+
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

12-
# 1-arg functions
13-
function MAK.initialize_output(::typeof(left_null!), t::AdjointTensorMap, alg::AbstractAlgorithm)
14-
return adjoint(MAK.initialize_output(right_null!, adjoint(t), _adjoint(alg)))
15-
end
16-
function MAK.initialize_output(
17-
::typeof(right_null!), t::AdjointTensorMap,
18-
alg::AbstractAlgorithm
19-
)
20-
return adjoint(MAK.initialize_output(left_null!, adjoint(t), _adjoint(alg)))
12+
for f in
13+
[
14+
:svd_compact, :svd_full, :svd_vals,
15+
:qr_compact, :qr_full, :qr_null,
16+
:lq_compact, :lq_full, :lq_null,
17+
:eig_full, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
18+
:left_polar, :right_polar,
19+
:project_hermitian, :project_antihermitian, :project_isometric,
20+
]
21+
f! = Symbol(f, :!)
22+
# just return the algorithm for the parent type since we are mapping this with
23+
# `_adjoint` afterwards anyways.
24+
# TODO: properly handle these cases
25+
@eval MAK.default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T <: AdjointTensorMap} =
26+
MAK.default_algorithm($f!, TensorKit.parenttype(T); kwargs...)
2127
end
2228

23-
function MAK.left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
24-
right_null!(adjoint(t), adjoint(N), _adjoint(alg))
25-
return N
26-
end
27-
function MAK.right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
28-
left_null!(adjoint(t), adjoint(N), _adjoint(alg))
29-
return N
30-
end
29+
# 1-arg functions
30+
MAK.initialize_output(::typeof(qr_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) =
31+
adjoint(MAK.initialize_output(lq_null!, adjoint(t), _adjoint(alg)))
32+
MAK.initialize_output(::typeof(lq_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) =
33+
adjoint(MAK.initialize_output(qr_null!, adjoint(t), _adjoint(alg)))
3134

32-
function MAK.is_left_isometry(t::AdjointTensorMap; kwargs...)
33-
return is_right_isometry(adjoint(t); kwargs...)
34-
end
35-
function MAK.is_right_isometry(t::AdjointTensorMap; kwargs...)
36-
return is_left_isometry(adjoint(t); kwargs...)
37-
end
35+
MAK.qr_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) =
36+
lq_null!(adjoint(t), adjoint(N), _adjoint(alg))
37+
MAK.lq_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) =
38+
qr_null!(adjoint(t), adjoint(N), _adjoint(alg))
39+
40+
MAK.is_left_isometric(t::AdjointTensorMap; kwargs...) =
41+
MAK.is_right_isometric(adjoint(t); kwargs...)
42+
MAK.is_right_isometric(t::AdjointTensorMap; kwargs...) =
43+
MAK.is_left_isometric(adjoint(t); kwargs...)
3844

3945
# 2-arg functions
40-
for (left_f!, right_f!) in zip(
41-
(:qr_full!, :qr_compact!, :left_polar!, :left_orth!),
42-
(:lq_full!, :lq_compact!, :right_polar!, :right_orth!)
46+
for (left_f, right_f) in zip(
47+
(:qr_full, :qr_compact, :left_polar),
48+
(:lq_full, :lq_compact, :right_polar)
4349
)
44-
@eval function MAK.copy_input(::typeof($left_f!), t::AdjointTensorMap)
45-
return adjoint(MAK.copy_input($right_f!, adjoint(t)))
50+
left_f! = Symbol(left_f, :!)
51+
right_f! = Symbol(right_f, :!)
52+
@eval function MAK.copy_input(::typeof($left_f), t::AdjointTensorMap)
53+
return adjoint(MAK.copy_input($right_f, adjoint(t)))
4654
end
47-
@eval function MAK.copy_input(::typeof($right_f!), t::AdjointTensorMap)
48-
return adjoint(MAK.copy_input($left_f!, adjoint(t)))
55+
@eval function MAK.copy_input(::typeof($right_f), t::AdjointTensorMap)
56+
return adjoint(MAK.copy_input($left_f, adjoint(t)))
4957
end
5058

5159
@eval function MAK.initialize_output(
@@ -60,29 +68,31 @@ for (left_f!, right_f!) in zip(
6068
end
6169

6270
@eval function MAK.$left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
63-
$right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
64-
return F
71+
F′ = $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
72+
return reverse(adjoint.(F′))
6573
end
6674
@eval function MAK.$right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
67-
$left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
68-
return F
75+
F′ = $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
76+
return reverse(adjoint.(F′))
6977
end
7078
end
7179

7280
# 3-arg functions
73-
for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
74-
@eval function MAK.copy_input(::typeof($f!), t::AdjointTensorMap)
75-
return adjoint(MAK.copy_input($f!, adjoint(t)))
81+
for f in (:svd_full, :svd_compact)
82+
f! = Symbol(f, :!)
83+
@eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap)
84+
return adjoint(MAK.copy_input($f, adjoint(t)))
7685
end
7786

7887
@eval function MAK.initialize_output(
7988
::typeof($f!), t::AdjointTensorMap, alg::AbstractAlgorithm
8089
)
8190
return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg))))
8291
end
92+
8393
@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
84-
$f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
85-
return F
94+
F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
95+
return reverse(adjoint.(F′))
8696
end
8797

8898
# disambiguate by prohibition
@@ -92,17 +102,9 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
92102
throw(MethodError($f!, (t, alg)))
93103
end
94104
end
105+
95106
# avoid amgiguity
96-
function MAK.initialize_output(
97-
::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm
98-
)
99-
return MAK.initialize_output(svd_compact!, t, alg.alg)
100-
end
101-
# to fix ambiguity
102-
function MAK.svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm)
103-
USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg)
104-
return MAK.truncate(svd_trunc!, USVᴴ′, alg.trunc)
105-
end
106-
function MAK.svd_compact!(t::AdjointTensorMap, USVᴴ, alg::DiagonalAlgorithm)
107-
return MAK.svd_compact!(t, USVᴴ, alg.alg)
107+
function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm)
108+
F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
109+
return reverse(adjoint.(F′))
108110
end

src/factorizations/diagonal.jl

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -81,55 +81,20 @@ for f! in
8181
:eigh_trunc!, :right_orth!, :left_orth!,
8282
)
8383
@eval function MAK.$f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm)
84-
MAK.check_input($f!, d, F, alg)
8584
$f!(_repack_diagonal(d), _repack_diagonal.(F), alg)
8685
return F
8786
end
8887
end
8988

90-
for f! in (:qr_full!, :qr_compact!)
91-
@eval function MAK.check_input(
92-
::typeof($f!), d::AbstractTensorMap, QR, ::DiagonalAlgorithm
93-
)
94-
Q, R = QR
95-
@assert d isa DiagonalTensorMap
96-
@assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap
97-
@check_scalar Q d
98-
@check_scalar R d
99-
@check_space(Q, space(d))
100-
@check_space(R, space(d))
101-
102-
return nothing
103-
end
104-
end
105-
106-
for f! in (:lq_full!, :lq_compact!)
107-
@eval function MAK.check_input(
108-
::typeof($f!), d::AbstractTensorMap, LQ, ::DiagonalAlgorithm
109-
)
110-
L, Q = LQ
111-
@assert d isa DiagonalTensorMap
112-
@assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap
113-
@check_scalar Q d
114-
@check_scalar L d
115-
@check_space(Q, space(d))
116-
@check_space(L, space(d))
117-
118-
return nothing
119-
end
120-
end
121-
12289
# disambiguate
12390
function MAK.svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm)
12491
return svd_full!(t, USVᴴ, alg)
12592
end
12693

12794
# f_vals
12895
# ------
129-
13096
for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
13197
@eval function MAK.$f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm)
132-
MAK.check_input($f!, d, V, alg)
13398
$f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg)
13499
return V
135100
end
@@ -140,64 +105,3 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
140105
return DiagonalTensorMap(data, d.domain)
141106
end
142107
end
143-
144-
function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm)
145-
domain(t) == codomain(t) ||
146-
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
147-
148-
D, V = DV
149-
150-
@assert D isa DiagonalTensorMap
151-
@assert V isa AbstractTensorMap
152-
153-
# scalartype checks
154-
@check_scalar D t
155-
@check_scalar V t
156-
157-
# space checks
158-
@check_space D space(t)
159-
@check_space V space(t)
160-
161-
return nothing
162-
end
163-
164-
function MAK.check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm)
165-
domain(t) == codomain(t) ||
166-
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
167-
168-
D, V = DV
169-
170-
@assert D isa DiagonalTensorMap
171-
@assert V isa AbstractTensorMap
172-
173-
# scalartype checks
174-
@check_scalar D t real
175-
@check_scalar V t
176-
177-
# space checks
178-
@check_space D space(t)
179-
@check_space V space(t)
180-
181-
return nothing
182-
end
183-
184-
function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
185-
@assert D isa DiagonalTensorMap
186-
@check_scalar D t
187-
@check_space D space(t)
188-
return nothing
189-
end
190-
191-
function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
192-
@assert D isa DiagonalTensorMap
193-
@check_scalar D t real
194-
@check_space D space(t)
195-
return nothing
196-
end
197-
198-
function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
199-
@assert D isa DiagonalTensorMap
200-
@check_scalar D t real
201-
@check_space D space(t)
202-
return nothing
203-
end

src/factorizations/factorizations.jl

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ import MatrixAlgebraKit as MAK
1818
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
1919
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
2020
TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder
21-
using MatrixAlgebraKit: left_orth_polar!, right_orth_polar!, left_orth_svd!,
22-
right_orth_svd!, left_null_svd!, right_null_svd!, diagview
21+
using MatrixAlgebraKit: diagview
2322

2423
include("utility.jl")
2524
include("matrixalgebrakit.jl")
@@ -30,11 +29,6 @@ include("pullbacks.jl")
3029

3130
TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A)
3231

33-
function MatrixAlgebraKit.isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
34-
t = permute(t, (p₁, p₂); copy = false)
35-
return isisometry(t)
36-
end
37-
3832
#------------------------------#
3933
# LinearAlgebra overloads
4034
#------------------------------#
@@ -61,38 +55,42 @@ LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t))
6155
#--------------------------------------------------#
6256
# Checks for hermiticity and positive definiteness #
6357
#--------------------------------------------------#
64-
function LinearAlgebra.ishermitian(t::AbstractTensorMap)
65-
domain(t) == codomain(t) || return false
66-
InnerProductStyle(t) === EuclideanInnerProduct() || return false # hermiticity only defined for euclidean
67-
for (c, b) in blocks(t)
68-
ishermitian(b) || return false
58+
function _blockmap(f; kwargs...)
59+
return function ((c, b))
60+
return f(b; kwargs...)
6961
end
70-
return true
7162
end
7263

64+
function MAK.ishermitian(t::AbstractTensorMap; kwargs...)
65+
return InnerProductStyle(t) === EuclideanInnerProduct() &&
66+
domain(t) == codomain(t) &&
67+
all(_blockmap(MAK.ishermitian; kwargs...), blocks(t))
68+
end
69+
function MAK.isantihermitian(t::AbstractTensorMap; kwargs...)
70+
return InnerProductStyle(t) === EuclideanInnerProduct() &&
71+
domain(t) == codomain(t) &&
72+
all(_blockmap(MAK.isantihermitian; kwargs...), blocks(t))
73+
end
74+
LinearAlgebra.ishermitian(t::AbstractTensorMap) = MAK.ishermitian(t)
75+
7376
function LinearAlgebra.isposdef(t::AbstractTensorMap)
7477
return isposdef!(copy_oftype(t, factorisation_scalartype(isposdef, t)))
7578
end
7679
function LinearAlgebra.isposdef!(t::AbstractTensorMap)
7780
domain(t) == codomain(t) ||
7881
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
7982
InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
80-
for (c, b) in blocks(t)
81-
isposdef!(b) || return false
82-
end
83-
return true
83+
return all(_blockmap(isposdef!), blocks(t))
8484
end
8585

8686
# TODO: tolerances are per-block, not global or weighted - does that matter?
87-
function MatrixAlgebraKit.is_left_isometry(t::AbstractTensorMap; kwargs...)
87+
function MAK.is_left_isometric(t::AbstractTensorMap; kwargs...)
8888
domain(t) codomain(t) || return false
89-
f((c, b)) = MatrixAlgebraKit.is_left_isometry(b; kwargs...)
90-
return all(f, blocks(t))
89+
return all(_blockmap(MAK.is_left_isometric; kwargs...), blocks(t))
9190
end
92-
function MatrixAlgebraKit.is_right_isometry(t::AbstractTensorMap; kwargs...)
91+
function MAK.is_right_isometric(t::AbstractTensorMap; kwargs...)
9392
domain(t) codomain(t) || return false
94-
f((c, b)) = MatrixAlgebraKit.is_right_isometry(b; kwargs...)
95-
return all(f, blocks(t))
93+
return all(_blockmap(MAK.is_right_isometric; kwargs...), blocks(t))
9694
end
9795

9896
end

0 commit comments

Comments
 (0)