Skip to content

Commit 2dffd1d

Browse files
committed
some simplifications and extensions
1 parent 80fe448 commit 2dffd1d

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

src/factorizations/matrixalgebrakit.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ end
8181

8282
function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
8383
V_cod = infimum(fuse(codomain(t)), fuse(domain(t)))
84-
return diagview(DiagonalTensorMap{real(scalartype(t))}(undef, V_cod))
84+
T = real(scalartype(t))
85+
return SectorVector{T}(undef, V_cod)
8586
end
8687

8788
# Eigenvalue decomposition
@@ -105,13 +106,13 @@ end
105106
function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
106107
V_D = fuse(domain(t))
107108
T = real(scalartype(t))
108-
return diagview(DiagonalTensorMap{Tc}(undef, V_D))
109+
return SectorVector{T}(undef, V_D)
109110
end
110111

111112
function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
112113
V_D = fuse(domain(t))
113114
Tc = complex(scalartype(t))
114-
return diagview(DiagonalTensorMap{Tc}(undef, V_D))
115+
return SectorVector{Tc}(undef, V_cod)
115116
end
116117

117118
# QR decomposition

src/factorizations/truncation.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ end
7878
function MAK.truncate(
7979
::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy
8080
)
81-
extended_S = SectorVector{eltype(S)}(undef, fuse(codomain(U)))
82-
fill!(extended_S.data, zero(eltype(extended_S.data)))
81+
extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(codomain(U))))
8382
for (c, b) in blocks(S)
8483
copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter
8584
end
@@ -92,8 +91,7 @@ end
9291
function MAK.truncate(
9392
::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy
9493
)
95-
extended_S = SectorVector{eltype(S)}(undef, fuse(domain(Vᴴ)))
96-
fill!(extended_S.data, zero(eltype(extended_S.data)))
94+
extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(domain(Vᴴ))))
9795
for (c, b) in blocks(S)
9896
copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter
9997
end
@@ -184,12 +182,12 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false)
184182
perms = SectorDict(
185183
(
186184
begin
187-
p = sortperm(d; by, rev)
185+
p = sortperm(v; by, rev)
188186
vs = values_sorted[c]
189-
vs .= d[p]
187+
vs .= view(v, p)
190188
c => p
191189
end
192-
) for (c, d) in pairs(values)
190+
) for (c, v) in pairs(values)
193191
)
194192
return values_sorted, perms
195193
end
@@ -300,5 +298,5 @@ function MAK.truncation_error!(values::SectorVector, ind)
300298
v = values[c]
301299
v[ind_c] .= zero(eltype(v))
302300
end
303-
return TensorKit._norm(pairs(values), 2, zero(real(eltype(valtype(values)))))
301+
return norm(values)
304302
end

src/tensors/sectorvector.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,49 @@ Base.similar(v::SectorVector, V::ElementarySpace) = SectorVector(undef, V)
5757
blocksectors(v::SectorVector) = keys(v)
5858
blocks(v::SectorVector) = pairs(v)
5959
block(v::SectorVector{T, I, A}, c::I) where {T, I, A} = Base.getindex(v, c)
60+
61+
# VectorInterface and LinearAlgebra interface
62+
# ----------------------------------------------
63+
VectorInterface.zerovector(v::SectorVector, ::Type{T}) where {T} =
64+
SectorVector(zerovector(parent(v), T), v.structure)
65+
VectorInterface.zerovector!(v::SectorVector) = (zerovector!(parent(v)); return v)
66+
VectorInterface.zerovector!!(v::SectorVector) = (zerovector!!(parent(v)); return v)
67+
68+
VectorInterface.scale(v::SectorVector, α) = SectorVector(scale(parent(v), α), v.structure)
69+
VectorInterface.scale!(v::SectorVector, α) = (scale!(parent(v), α); return v)
70+
VectorInterface.scale!!(v::SectorVector, α) = (scale!!(parent(v), α); return v)
71+
72+
function VectorInterface.add(v1::SectorVector, v2::SectorVector, α = One(), β = One())
73+
SectorVector(add(parent(v1), parent(v2), α, β), v1.structure)
74+
end
75+
function VectorInterface.add!(v1::SectorVector, v2::SectorVector, α = One(), β = One())
76+
add!(parent(v1), parent(v2), α, β)
77+
return v1
78+
end
79+
function VectorInterface.add!!(v1::SectorVector, v2::SectorVector, α = One(), β = One())
80+
add!!(parent(v1), parent(v2), α, β)
81+
return v1
82+
end
83+
84+
function VectorInterface.inner(v1::SectorVector, v2::SectorVector)
85+
v1.structure == v2.structure || throw(SpaceMismatch("Sector structures do not match"))
86+
I = sectortype(v1)
87+
if FusionStyle(I) isa UniqueFusion # all quantum dimensions are one
88+
return inner(parent(v1), parent(v2))
89+
else
90+
T = VectorInterface.promote_inner(v1, v2)
91+
s = zero(T)
92+
for c in blocksectors(v1)
93+
b1 = block(v1, c)
94+
b2 = block(v2, c)
95+
s += convert(T, dim(c)) * inner(b1, b2)
96+
end
97+
end
98+
return s
99+
end
100+
101+
LinearAlgebra.dot(v1::SectorVector, v2::SectorVector) = inner(v1, v2)
102+
103+
function LinearAlgebra.norm(v::SectorVector, p::Real = 2)
104+
return _norm(blocks(v), p, float(zero(real(scalartype(v)))))
105+
end

0 commit comments

Comments
 (0)