Skip to content

Commit 412a45c

Browse files
lkdvosJutho
andauthored
Add SectorVector (#324)
* initial basic design SectorVector * some additional functionality * relax `foreachblock` signature * replace `SectorDict` with `SectorVector` for eig/svdvals * export `svd_vals` * clean up SectorVector design * small fix * add finitedifferences support * update changelog * some simplifications and extensions * some further fixes * some more fixes * update dates --------- Co-authored-by: Jutho Haegeman <[email protected]>
1 parent a93edc6 commit 412a45c

File tree

13 files changed

+199
-89
lines changed

13 files changed

+199
-89
lines changed

CITATION.cff

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors:
1010
title: "TensorKit.jl"
1111
version: "0.16.0"
1212
doi: "10.5281/zenodo.8421339"
13-
date-released: "2025-12-05"
13+
date-released: "2025-12-08"
1414
url: "https://github.com/QuantumKitHub/TensorKit.jl"
1515
preferred-citation:
1616
type: article

docs/src/Changelog.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
2020

2121
## [Unreleased](https://github.com/QuantumKitHub/TensorKit.jl/compare/v0.16.0...HEAD)
2222

23-
## [0.16.0](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.16.0) - 2025-12-05
23+
## [0.16.0](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.16.0) - 2025-12-08
2424

2525
### Added
2626

@@ -38,6 +38,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
3838
- Major documentation update/overhaul ([#289](https://github.com/QuantumKitHub/TensorKit.jl/pull/289))
3939
- Added symmetric tensor tutorial as appendix ([#316](https://github.com/QuantumKitHub/TensorKit.jl/pull/316))
4040
- Improved error messages throughout codebase ([#309](https://github.com/QuantumKitHub/TensorKit.jl/pull/309))
41+
- `eigvals` and `svdvals` now output `SectorVector` objects, which do behave as `AbstractVector` but also have the option of iterating the blocks through `Base.pairs`. ([#324](https://github.com/QuantumKitHub/TensorKit.jl/pull/309)
4142

4243
### Deprecated
4344

@@ -52,6 +53,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
5253

5354
- Avoid unnecessary copy in `twist` for tensors with bosonic braiding ([#305](https://github.com/QuantumKitHub/TensorKit.jl/pull/305))
5455
- Small fixes and typos ([#295](https://github.com/QuantumKitHub/TensorKit.jl/pull/295))
56+
- `eig_vals`, `svd_vals`, etc now all output `SectorVector` objects instead of `DiagonalTensorMap`s, in line with how MatrixAlgebraKit returns `Vector`s instead of `Diagonal`s ([#324](https://github.com/QuantumKitHub/TensorKit.jl/pull/309)
5557

5658
## [0.15.3](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.15.3) - 2025-10-30
5759

ext/TensorKitFiniteDifferencesExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TensorKitFiniteDifferencesExt
22

33
using TensorKit
4-
using TensorKit: sqrtdim, invsqrtdim
4+
using TensorKit: sqrtdim, invsqrtdim, SectorVector
55
using VectorInterface: scale!
66
using FiniteDifferences
77

@@ -31,6 +31,25 @@ function FiniteDifferences.to_vec(t::DiagonalTensorMap)
3131
return x_vec, DiagonalTensorMap_from_vec
3232
end
3333

34+
function FiniteDifferences.to_vec(v::SectorVector{T, <:Sector}) where {T}
35+
v_normalized = similar(v)
36+
for (c, b) in pairs(v)
37+
scale!(v_normalized[c], b, sqrtdim(c))
38+
end
39+
vec = parent(v_normalized)
40+
vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec))
41+
42+
function from_vec(x_real)
43+
x = T <: Real ? x_real : reinterpret(T, x_real)
44+
v_result = SectorVector(x, v.structure)
45+
for (c, b) in pairs(v_result)
46+
scale!(b, invsqrtdim(c))
47+
end
48+
return v_result
49+
end
50+
return vec_real, from_vec
51+
end
52+
3453
end
3554

3655
# TODO: Investigate why the approach below doesn't work

src/TensorKit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ export left_orth, right_orth, left_null, right_null,
7979
left_polar, left_polar!, right_polar, right_polar!,
8080
qr_full, qr_compact, qr_null, lq_full, lq_compact, lq_null,
8181
qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!,
82-
svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc,
82+
svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc, svd_vals, svd_vals!,
8383
exp, exp!,
8484
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals,
8585
eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals,
@@ -222,6 +222,7 @@ end
222222
include("tensors/abstracttensor.jl")
223223
include("tensors/backends.jl")
224224
include("tensors/blockiterator.jl")
225+
include("tensors/sectorvector.jl")
225226
include("tensors/tensor.jl")
226227
include("tensors/adjoint.jl")
227228
include("tensors/linalg.jl")

src/factorizations/diagonal.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# -----------------
33
_repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data)
44

5+
MAK.diagview(t::DiagonalTensorMap) = SectorVector(t.data, TensorKit.diagonalblockstructure(space(t)))
6+
57
for f in (
68
:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null,
79
:lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full,

src/factorizations/factorizations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module Factorizations
66
export copy_oftype, factorisation_scalartype, one!, truncspace
77

88
using ..TensorKit
9-
using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one!
9+
using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector, blocktype, foreachblock, one!
1010

1111
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!,
1212
isposdef, isposdef!, ishermitian
@@ -44,13 +44,13 @@ function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
4444
tcopy = copy_oftype(t, factorisation_scalartype(LinearAlgebra.eigen, t))
4545
return LinearAlgebra.eigvals!(tcopy; kwargs...)
4646
end
47-
LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = diagview(eig_vals!(t))
47+
LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = eig_vals!(t)
4848

4949
function LinearAlgebra.svdvals(t::AbstractTensorMap)
5050
tcopy = copy_oftype(t, factorisation_scalartype(svd_vals!, t))
5151
return LinearAlgebra.svdvals!(tcopy)
5252
end
53-
LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t))
53+
LinearAlgebra.svdvals!(t::AbstractTensorMap) = svd_vals!(t)
5454

5555
#--------------------------------------------------#
5656
# Checks for hermiticity and positive definiteness #

src/factorizations/matrixalgebrakit.jl

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ for f! in (
4444
end
4545

4646
# Handle these separately because single output instead of tuple
47-
for f! in (:qr_null!, :lq_null!, :project_hermitian!, :project_antihermitian!, :project_isometric!)
47+
for f! in (
48+
:qr_null!, :lq_null!,
49+
:svd_vals!, :eig_vals!, :eigh_vals!,
50+
:project_hermitian!, :project_antihermitian!, :project_isometric!,
51+
)
4852
@eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm)
4953
foreachblock(t, N) do _, (tblock, Nblock)
5054
Nblock′ = $f!(tblock, Nblock, alg)
@@ -56,19 +60,6 @@ for f! in (:qr_null!, :lq_null!, :project_hermitian!, :project_antihermitian!, :
5660
end
5761
end
5862

59-
# Handle these separately because single output instead of tuple
60-
for f! in (:svd_vals!, :eig_vals!, :eigh_vals!)
61-
@eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm)
62-
foreachblock(t, N) do _, (tblock, Nblock)
63-
Nblock′ = $f!(tblock, diagview(Nblock), alg)
64-
# deal with the case where the output is not the same as the input
65-
diagview(Nblock) === Nblock′ || copy!(diagview(Nblock), Nblock′)
66-
return nothing
67-
end
68-
return N
69-
end
70-
end
71-
7263
# Singular value decomposition
7364
# ----------------------------
7465
function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm)
@@ -90,7 +81,8 @@ end
9081

9182
function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
9283
V_cod = infimum(fuse(codomain(t)), fuse(domain(t)))
93-
return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod)
84+
T = real(scalartype(t))
85+
return SectorVector{T}(undef, V_cod)
9486
end
9587

9688
# Eigenvalue decomposition
@@ -114,13 +106,13 @@ end
114106
function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
115107
V_D = fuse(domain(t))
116108
T = real(scalartype(t))
117-
return D = DiagonalTensorMap{Tc}(undef, V_D)
109+
return SectorVector{T}(undef, V_D)
118110
end
119111

120112
function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
121113
V_D = fuse(domain(t))
122114
Tc = complex(scalartype(t))
123-
return D = DiagonalTensorMap{Tc}(undef, V_D)
115+
return SectorVector{Tc}(undef, V_D)
124116
end
125117

126118
# QR decomposition

src/factorizations/truncation.jl

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ end
7878
function MAK.truncate(
7979
::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy
8080
)
81-
extended_S = SectorDict(
82-
c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 1) - size(b, 2))))
83-
for (c, b) in blocks(S)
84-
)
81+
extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(codomain(U))))
82+
for (c, b) in blocks(S)
83+
copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter
84+
end
8585
ind = MAK.findtruncated(extended_S, strategy)
8686
V_truncated = truncate_space(space(S, 1), ind)
8787
= similar(U, codomain(U) V_truncated)
@@ -91,10 +91,10 @@ end
9191
function MAK.truncate(
9292
::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy
9393
)
94-
extended_S = SectorDict(
95-
c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1))))
96-
for (c, b) in blocks(S)
97-
)
94+
extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(domain(Vᴴ))))
95+
for (c, b) in blocks(S)
96+
copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter
97+
end
9898
ind = MAK.findtruncated(extended_S, strategy)
9999
V_truncated = truncate_space(dual(space(S, 2)), ind)
100100
Ṽᴴ = similar(Vᴴ, V_truncated domain(Vᴴ))
@@ -177,26 +177,40 @@ function _findnexttruncvalue(
177177
end
178178
end
179179

180+
function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false)
181+
values_sorted = similar(values)
182+
perms = SectorDict(
183+
(
184+
begin
185+
p = sortperm(v; by, rev)
186+
vs = values_sorted[c]
187+
vs .= view(v, p)
188+
c => p
189+
end
190+
) for (c, v) in pairs(values)
191+
)
192+
return values_sorted, perms
193+
end
194+
180195
# findtruncated
181196
# -------------
182197
# Generic fallback
183-
function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationStrategy)
198+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationStrategy)
184199
return MAK.findtruncated(values, strategy)
185200
end
186201

187-
function MAK.findtruncated(values::SectorDict, ::NoTruncation)
188-
return SectorDict(c => Colon() for (c, b) in values)
202+
function MAK.findtruncated(values::SectorVector, ::NoTruncation)
203+
return SectorDict(c => Colon() for c in keys(values))
189204
end
190205

191-
function MAK.findtruncated(values::SectorDict, strategy::TruncationByOrder)
192-
perms = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) for (c, d) in values)
193-
values_sorted = SectorDict(c => d[perms[c]] for (c, d) in values)
206+
function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder)
207+
values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev)
194208
inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany))
195209
return SectorDict(c => perms[c][I] for (c, I) in inds)
196210
end
197-
function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByOrder)
211+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder)
198212
I = keytype(values)
199-
truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in values)
213+
truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values))
200214
totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0)
201215
while totaldim > strategy.howmany
202216
next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev)
@@ -209,32 +223,31 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByOrder)
209223
return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim)
210224
end
211225

212-
function MAK.findtruncated(values::SectorDict, strategy::TruncationByFilter)
213-
return SectorDict(c => findall(strategy.filter, d) for (c, d) in values)
226+
function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter)
227+
return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values))
214228
end
215229

216-
function MAK.findtruncated(values::SectorDict, strategy::TruncationByValue)
230+
function MAK.findtruncated(values::SectorVector, strategy::TruncationByValue)
217231
atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
218232
strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
219-
return SectorDict(c => MAK.findtruncated(d, strategy′) for (c, d) in values)
233+
return SectorDict(c => MAK.findtruncated(d, strategy′) for (c, d) in pairs(values))
220234
end
221-
function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByValue)
235+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByValue)
222236
atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
223237
strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
224-
return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in values)
238+
return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in pairs(values))
225239
end
226240

227-
function MAK.findtruncated(values::SectorDict, strategy::TruncationByError)
228-
perms = SectorDict(c => sortperm(d; by = abs, rev = true) for (c, d) in values)
229-
values_sorted = SectorDict(c => d[perms[c]] for (c, d) in Sd)
241+
function MAK.findtruncated(values::SectorVector, strategy::TruncationByError)
242+
values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev)
230243
inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany))
231244
return SectorDict(c => perms[c][I] for (c, I) in inds)
232245
end
233-
function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByError)
246+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError)
234247
I = keytype(values)
235-
truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in values)
248+
truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values))
236249
by(c, v) = abs(v)^strategy.p * dim(c)
237-
Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), values)
250+
Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), pairs(values))
238251
ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ)
239252
truncerrᵖ = zero(real(scalartype(valtype(values))))
240253
next = _findnexttruncvalue(values, truncdim)
@@ -248,16 +261,16 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByError)
248261
return SectorDict{I, Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim)
249262
end
250263

251-
function MAK.findtruncated(values::SectorDict, strategy::TruncationSpace)
264+
function MAK.findtruncated(values::SectorVector, strategy::TruncationSpace)
252265
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
253266
return SectorDict(c => MAK.findtruncated(d, blockstrategy(c)) for (c, d) in values)
254267
end
255-
function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationSpace)
268+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace)
256269
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
257-
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in values)
270+
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values))
258271
end
259272

260-
function MAK.findtruncated(values::SectorDict, strategy::TruncationIntersection)
273+
function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection)
261274
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
262275
return SectorDict(
263276
c => mapreduce(
@@ -266,7 +279,7 @@ function MAK.findtruncated(values::SectorDict, strategy::TruncationIntersection)
266279
) for c in intersect(map(keys, inds)...)
267280
)
268281
end
269-
function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationIntersection)
282+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection)
270283
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
271284
return SectorDict(
272285
c => mapreduce(
@@ -278,13 +291,12 @@ end
278291

279292
# Truncation error
280293
# ----------------
281-
MAK.truncation_error(values::SectorDict, ind) =
282-
MAK.truncation_error!(SectorDict(c => copy(v) for (c, v) in values), ind)
294+
MAK.truncation_error(values::SectorVector, ind) = MAK.truncation_error!(copy(values), ind)
283295

284-
function MAK.truncation_error!(values::SectorDict, ind)
296+
function MAK.truncation_error!(values::SectorVector, ind)
285297
for (c, ind_c) in ind
286298
v = values[c]
287299
v[ind_c] .= zero(eltype(v))
288300
end
289-
return TensorKit._norm(values, 2, zero(real(eltype(valtype(values)))))
301+
return norm(values)
290302
end

src/tensors/blockiterator.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ for c in union(blocksectors.(ts)...)
3030
end
3131
```
3232
"""
33-
function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler = nothing)
33+
function foreachblock(f, t, ts...; scheduler = nothing)
3434
tensors = (t, ts...)
3535
allsectors = union(blocksectors.(tensors)...)
3636
foreach(allsectors) do c
3737
return f(c, block.(tensors, Ref(c)))
3838
end
3939
return nothing
4040
end
41-
function foreachblock(f, t::AbstractTensorMap; scheduler = nothing)
41+
function foreachblock(f, t; scheduler = nothing)
4242
foreach(blocks(t)) do (c, b)
4343
return f(c, (b,))
4444
end

src/tensors/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ function LinearAlgebra.rank(
300300
dim(t) == 0 && return r
301301
S = LinearAlgebra.svdvals(t)
302302
tol = max(atol, rtol * maximum(first, values(S)))
303-
for (c, b) in S
303+
for (c, b) in pairs(S)
304304
if !isempty(b)
305305
r += dim(c) * count(>(tol), b)
306306
end

0 commit comments

Comments
 (0)