@@ -21,7 +21,7 @@ using TensorOperations: Index2Tuple
2121using MatrixAlgebraKit
2222using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy,
2323 NoTruncation, TruncationKeepAbove, TruncationKeepBelow,
24- TruncationIntersection, TruncationKeepFiltered
24+ TruncationIntersection, TruncationKeepFiltered, DiagonalAlgorithm
2525import MatrixAlgebraKit: default_algorithm,
2626 copy_input, check_input, initialize_output,
2727 qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!,
@@ -41,6 +41,7 @@ include("matrixalgebrakit.jl")
4141include (" truncation.jl" )
4242include (" deprecations.jl" )
4343include (" adjoint.jl" )
44+ include (" diagonal.jl" )
4445
4546TensorKit. one! (A:: AbstractMatrix ) = MatrixAlgebraKit. one! (A)
4647
5556# ------------------------------------------------------------------------------------------
5657const RealOrComplexFloat = Union{AbstractFloat,Complex{<: AbstractFloat }}
5758
58- # DiagonalTensorMap
59- # -----------------
60- function leftorth! (d:: DiagonalTensorMap ; alg= QR (), kwargs... )
61- @assert alg isa Union{QR,QL}
62- return one (d), d # TODO : this is only correct for `alg = QR()` or `alg = QL()`
63- end
64- function rightorth! (d:: DiagonalTensorMap ; alg= LQ (), kwargs... )
65- @assert alg isa Union{LQ,RQ}
66- return d, one (d) # TODO : this is only correct for `alg = LQ()` or `alg = RQ()`
67- end
68- leftnull! (d:: DiagonalTensorMap ; kwargs... ) = leftnull! (TensorMap (d); kwargs... )
69- rightnull! (d:: DiagonalTensorMap ; kwargs... ) = rightnull! (TensorMap (d); kwargs... )
70-
71- function tsvd! (d:: DiagonalTensorMap ; trunc= NoTruncation (), p:: Real = 2 , alg= SDD ())
72- return _tsvd! (d, alg, trunc, p)
73- end
74-
75- # helper function
76- function _compute_svddata! (d:: DiagonalTensorMap , alg:: Union{SVD,SDD} )
77- InnerProductStyle (d) === EuclideanInnerProduct () || throw_invalid_innerproduct (:tsvd! )
78- I = sectortype (d)
79- dims = SectorDict {I,Int} ()
80- generator = Base. Iterators. map (blocks (d)) do (c, b)
81- lb = length (b. diag)
82- U = zerovector! (similar (b. diag, lb, lb))
83- V = zerovector! (similar (b. diag, lb, lb))
84- p = sortperm (b. diag; by= abs, rev= true )
85- for (i, pi ) in enumerate (p)
86- U[pi , i] = safesign (b. diag[pi ])
87- V[i, pi ] = 1
88- end
89- Σ = abs .(view (b. diag, p))
90- dims[c] = lb
91- return c => (U, Σ, V)
92- end
93- SVDdata = SectorDict (generator)
94- return SVDdata, dims
95- end
96-
97- eig! (d:: DiagonalTensorMap ) = d, one (d)
98- eigh! (d:: DiagonalTensorMap{<:Real} ) = d, one (d)
99- eigh! (d:: DiagonalTensorMap{<:Complex} ) = DiagonalTensorMap (real (d. data), d. domain), one (d)
100-
101- function LinearAlgebra. svdvals (d:: DiagonalTensorMap )
102- return SectorDict (c => LinearAlgebra. svdvals (b) for (c, b) in blocks (d))
103- end
104- function LinearAlgebra. eigvals (d:: DiagonalTensorMap )
105- return SectorDict (c => LinearAlgebra. eigvals (b) for (c, b) in blocks (d))
106- end
107-
108- function LinearAlgebra. cond (d:: DiagonalTensorMap , p:: Real = 2 )
109- return LinearAlgebra. cond (Diagonal (d. data), p)
110- end
11159# ------------------------------#
11260# Singular value decomposition #
11361# ------------------------------#
0 commit comments