Skip to content

Commit d80aab3

Browse files
committed
clean up handling DiagonalTensorMap
1 parent c06cb6c commit d80aab3

File tree

3 files changed

+81
-55
lines changed

3 files changed

+81
-55
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Combinatorics = "1"
3333
FiniteDifferences = "0.12"
3434
LRUCache = "1.0.2"
3535
LinearAlgebra = "1"
36-
MatrixAlgebraKit = "0.3"
36+
MatrixAlgebraKit = "0.3.1"
3737
OhMyThreads = "0.8.0"
3838
PackageExtensionCompat = "1"
3939
Random = "1"
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# DiagonalTensorMap
2+
# -----------------
3+
_repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data)
4+
5+
for f in [
6+
:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null,
7+
:lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full,
8+
:eigh_trunc, :eigh_vals, :left_polar, :right_polar,
9+
]
10+
@eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d)
11+
end
12+
13+
for f! in (:qr_full!, :qr_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!)
14+
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm)
15+
return d, similar(d)
16+
end
17+
end
18+
for f! in (:lq_full!, :lq_compact!)
19+
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm)
20+
return similar(d), d
21+
end
22+
end
23+
24+
for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!)
25+
@eval function $f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm)
26+
check_input($f!, d, F, alg)
27+
$f!(_repack_diagonal(d), _repack_diagonal.(F), alg)
28+
return F
29+
end
30+
end
31+
32+
for f! in (:qr_full!, :qr_compact!)
33+
@eval function check_input(::typeof($f!), d::AbstractTensorMap, (Q, R)::_T_QR, ::DiagonalAlgorithm)
34+
@assert d isa DiagonalTensorMap
35+
@assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap
36+
@check_scalar Q d
37+
@check_scalar R d
38+
@check_space(Q, space(d))
39+
@check_space(R, space(d))
40+
41+
return nothing
42+
end
43+
end
44+
45+
for f! in (:lq_full!, :lq_compact!)
46+
@eval function check_input(::typeof($f!), d::AbstractTensorMap, (L, Q)::_T_LQ, ::DiagonalAlgorithm)
47+
@assert d isa DiagonalTensorMap
48+
@assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap
49+
@check_scalar Q d
50+
@check_scalar L d
51+
@check_space(Q, space(d))
52+
@check_space(L, space(d))
53+
54+
return nothing
55+
end
56+
end
57+
58+
# f_vals
59+
# ------
60+
61+
for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
62+
@eval function $f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm)
63+
check_input($f!, d, V, alg)
64+
$f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg)
65+
return V
66+
end
67+
@eval function initialize_output(::typeof($f!), d::DiagonalTensorMap, alg::DiagonalAlgorithm)
68+
data = initialize_output($f!, _repack_diagonal(d), alg)
69+
return DiagonalTensorMap(data, d.domain)
70+
end
71+
end
72+
73+
function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap,
74+
::DiagonalAlgorithm)
75+
@check_scalar D t
76+
@check_space D space(t)
77+
return nothing
78+
end

src/tensors/factorizations/factorizations.jl

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using TensorOperations: Index2Tuple
2121
using MatrixAlgebraKit
2222
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy,
2323
NoTruncation, TruncationKeepAbove, TruncationKeepBelow,
24-
TruncationIntersection, TruncationKeepFiltered
24+
TruncationIntersection, TruncationKeepFiltered, DiagonalAlgorithm
2525
import MatrixAlgebraKit: default_algorithm,
2626
copy_input, check_input, initialize_output,
2727
qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!,
@@ -41,6 +41,7 @@ include("matrixalgebrakit.jl")
4141
include("truncation.jl")
4242
include("deprecations.jl")
4343
include("adjoint.jl")
44+
include("diagonal.jl")
4445

4546
TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A)
4647

@@ -55,59 +56,6 @@ end
5556
#------------------------------------------------------------------------------------------
5657
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
5758

58-
# DiagonalTensorMap
59-
# -----------------
60-
function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...)
61-
@assert alg isa Union{QR,QL}
62-
return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()`
63-
end
64-
function rightorth!(d::DiagonalTensorMap; alg=LQ(), kwargs...)
65-
@assert alg isa Union{LQ,RQ}
66-
return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()`
67-
end
68-
leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...)
69-
rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...)
70-
71-
function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
72-
return _tsvd!(d, alg, trunc, p)
73-
end
74-
75-
# helper function
76-
function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD})
77-
InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!)
78-
I = sectortype(d)
79-
dims = SectorDict{I,Int}()
80-
generator = Base.Iterators.map(blocks(d)) do (c, b)
81-
lb = length(b.diag)
82-
U = zerovector!(similar(b.diag, lb, lb))
83-
V = zerovector!(similar(b.diag, lb, lb))
84-
p = sortperm(b.diag; by=abs, rev=true)
85-
for (i, pi) in enumerate(p)
86-
U[pi, i] = safesign(b.diag[pi])
87-
V[i, pi] = 1
88-
end
89-
Σ = abs.(view(b.diag, p))
90-
dims[c] = lb
91-
return c => (U, Σ, V)
92-
end
93-
SVDdata = SectorDict(generator)
94-
return SVDdata, dims
95-
end
96-
97-
eig!(d::DiagonalTensorMap) = d, one(d)
98-
eigh!(d::DiagonalTensorMap{<:Real}) = d, one(d)
99-
eigh!(d::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(real(d.data), d.domain), one(d)
100-
101-
function LinearAlgebra.svdvals(d::DiagonalTensorMap)
102-
return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d))
103-
end
104-
function LinearAlgebra.eigvals(d::DiagonalTensorMap)
105-
return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d))
106-
end
107-
108-
function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2)
109-
return LinearAlgebra.cond(Diagonal(d.data), p)
110-
end
11159
#------------------------------#
11260
# Singular value decomposition #
11361
#------------------------------#

0 commit comments

Comments
 (0)