Skip to content

Commit 8d63119

Browse files
lkdvoskshyatt
andauthored
svd_vals(::DiagonalTensorMap) should return a SectorVector (#333)
* use `SectorVector` as output for `svd_vals(::DiagonalTensorMap)` * Don't sort diagonal svd_vals between blocks * update tests * update changelog * be careful about not mixing eigenvalues through blocks * LA svd comment * Don't overload LinearAlgebra's svd for now * Fix Changelog --------- Co-authored-by: Katharine Hyatt <[email protected]>
1 parent 155aa89 commit 8d63119

File tree

4 files changed

+34
-48
lines changed

4 files changed

+34
-48
lines changed

docs/src/Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
2424

2525
- Extended support for selecting storage types in the `TensorMap` constructors ([#327](https://github.com/QuantumKitHub/TensorKit.jl/pull/327))
2626
- `similar_diagonal` to handle storage types when constructing diagonals ([#330](https://github.com/QuantumKitHub/TensorKit.jl/pull/330))
27-
- `LinearAlgebra.svd` overloads
2827

2928
### Fixed
3029

3130
- Issue with using relative tolerances in truncation schemes ([#314](https://github.com/QuantumKitHub/TensorKit.jl/issues/314))
3231
- Using `scalartype` instead of `eltype` in BLAS contraction ([#326](https://github.com/QuantumKitHub/TensorKit.jl/pull/326))
3332
- Divide by zero error in `show` for empty tensors ([#329](https://github.com/QuantumKitHub/TensorKit.jl/pull/329))
33+
- `svd_vals(::DiagonalTensorMap)` correctly outputs `SectorVector` and implementation fix. ([#333](https://github.com/QuantumKitHub/TensorKit.jl/pull/329))
3434

3535
## [0.16.0](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.16.0) - 2025-12-08
3636

src/factorizations/diagonal.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# DiagonalTensorMap
22
# -----------------
33
_repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data)
4+
_repack_diagonal(d::SectorVector) = Diagonal(parent(d))
45

56
MAK.diagview(t::DiagonalTensorMap) = SectorVector(t.data, TensorKit.diagonalblockstructure(space(t)))
67

@@ -93,17 +94,10 @@ function MAK.svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm)
9394
return svd_full!(t, USVᴴ, alg)
9495
end
9596

96-
# f_vals
97-
# ------
98-
for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
99-
@eval function MAK.$f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm)
100-
$f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg)
101-
return V
102-
end
103-
@eval function MAK.initialize_output(
104-
::typeof($f!), d::DiagonalTensorMap, alg::DiagonalAlgorithm
105-
)
106-
data = MAK.initialize_output($f!, _repack_diagonal(d), alg)
107-
return DiagonalTensorMap(data, d.domain)
108-
end
97+
# For diagonal inputs we don't have to promote the scalartype since we know they are symmetric
98+
function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::DiagonalAlgorithm)
99+
V_D = fuse(domain(t))
100+
Tc = scalartype(t)
101+
A = similarstoragetype(t, Tc)
102+
return SectorVector{Tc, sectortype(t), A}(undef, V_D)
109103
end

src/factorizations/factorizations.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
4949
end
5050
LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = eig_vals!(t)
5151

52-
LinearAlgebra.svd(t::AbstractTensorMap; full::Bool = false) =
53-
full ? svd_full(t) : svd_compact(t)
54-
LinearAlgebra.svd!(t::AbstractTensorMap; full::Bool = false) =
55-
full ? svd_full!(t) : svd_compact!(t)
56-
5752
function LinearAlgebra.svdvals(t::AbstractTensorMap)
5853
tcopy = copy_oftype(t, factorisation_scalartype(svd_vals!, t))
5954
return LinearAlgebra.svdvals!(tcopy)

test/tensors/factorizations.jl

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test, TestExtras
22
using TensorKit
33
using LinearAlgebra: LinearAlgebra
4+
using MatrixAlgebraKit: diagview
45

56
@isdefined(TestSetup) || include("../setup.jl")
67
using .TestSetup
@@ -191,10 +192,9 @@ for V in spacelist
191192
@test isposdef(s)
192193
@test isisometric(vᴴ; side = :right)
193194

194-
s′ = LinearAlgebra.diag(s)
195-
for (c, b) in pairs(LinearAlgebra.svdvals(t))
196-
@test b s′[c]
197-
end
195+
s′ = @constinferred svd_vals(t)
196+
@test s′ diagview(s)
197+
@test s′ isa TensorKit.SectorVector
198198

199199
v, c = @constinferred left_orth(t; alg = :svd)
200200
@test v * c t
@@ -261,14 +261,14 @@ for V in spacelist
261261
@test norm(t - U1 * S1 * Vᴴ1) ϵ1 atol = eps(real(T))^(4 / 5)
262262
@test dim(domain(S1)) <= nvals
263263

264-
λ = minimum(minimum, values(LinearAlgebra.diag(S1)))
264+
λ = minimum(diagview(S1))
265265
trunc = trunctol(; atol = λ - 10eps(λ))
266266
U2, S2, Vᴴ2, ϵ2 = @constinferred svd_trunc(t; trunc)
267267
@test t * Vᴴ2' U2 * S2
268268
@test isisometric(U2)
269269
@test isisometric(Vᴴ2; side = :right)
270270
@test norm(t - U2 * S2 * Vᴴ2) ϵ2 atol = eps(real(T))^(4 / 5)
271-
@test minimum(minimum, values(LinearAlgebra.diag(S1))) >= λ
271+
@test minimum(diagview(S1)) >= λ
272272
@test U2 U1
273273
@test S2 S1
274274
@test Vᴴ2 Vᴴ1
@@ -297,7 +297,7 @@ for V in spacelist
297297
@test isisometric(U5)
298298
@test isisometric(Vᴴ5; side = :right)
299299
@test norm(t - U5 * S5 * Vᴴ5) ϵ5 atol = eps(real(T))^(4 / 5)
300-
@test minimum(minimum, values(LinearAlgebra.diag(S5))) >= λ
300+
@test minimum(diagview(S5)) >= λ
301301
@test dim(domain(S5)) nvals
302302
end
303303
end
@@ -312,13 +312,11 @@ for V in spacelist
312312
d, v = @constinferred eig_full(t)
313313
@test t * v v * d
314314

315-
d′ = LinearAlgebra.diag(d)
316-
for (c, b) in pairs(LinearAlgebra.eigvals(t))
317-
@test sort(b; by = abs) sort(d′[c]; by = abs)
318-
end
315+
d′ = @constinferred eig_vals(t)
316+
@test d′ diagview(d)
317+
@test d′ isa TensorKit.SectorVector
319318

320-
vdv = v' * v
321-
vdv = (vdv + vdv') / 2
319+
vdv = project_hermitian!(v' * v)
322320
@test @constinferred isposdef(vdv)
323321
t isa DiagonalTensorMap || @test !isposdef(t) # unlikely for non-hermitian map
324322

@@ -327,35 +325,34 @@ for V in spacelist
327325
@test t * v v * d
328326
@test dim(domain(d)) nvals
329327

330-
t2 = (t + t')
328+
t2 = @constinferred project_hermitian(t)
331329
D, V = eigen(t2)
332330
@test isisometric(V)
333331
D̃, Ṽ = @constinferred eigh_full(t2)
334332
@test D
335333
@test V
336-
λ = minimum(
337-
minimum(real(LinearAlgebra.diag(b)))
338-
for (c, b) in blocks(D)
339-
)
334+
λ = minimum(real, diagview(D))
340335
@test cond(Ṽ) one(real(T))
341336
@test isposdef(t2) == isposdef(λ)
342337
@test isposdef(t2 - λ * one(t2) + 0.1 * one(t2))
343338
@test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2))
344339

345-
add!(t, t')
346-
347-
d, v = @constinferred eigh_full(t)
348-
@test t * v v * d
340+
d, v = @constinferred eigh_full(t2)
341+
@test t2 * v v * d
349342
@test isunitary(v)
350343

351-
λ = minimum(minimum(real(LinearAlgebra.diag(b))) for (c, b) in blocks(d))
344+
d′ = @constinferred eigh_vals(t2)
345+
@test d′ diagview(d)
346+
@test d′ isa TensorKit.SectorVector
347+
348+
λ = minimum(real, diagview(d))
352349
@test cond(v) one(real(T))
353-
@test isposdef(t) == isposdef(λ)
354-
@test isposdef(t - λ * one(t) + 0.1 * one(t))
355-
@test !isposdef(t - λ * one(t) - 0.1 * one(t))
350+
@test isposdef(t2) == isposdef(λ)
351+
@test isposdef(t2 - λ * one(t) + 0.1 * one(t2))
352+
@test !isposdef(t2 - λ * one(t) - 0.1 * one(t2))
356353

357-
d, v = @constinferred eigh_trunc(t; trunc = truncrank(nvals))
358-
@test t * v v * d
354+
d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals))
355+
@test t2 * v v * d
359356
@test dim(domain(d)) nvals
360357
end
361358
end
@@ -390,7 +387,7 @@ for V in spacelist
390387
@test cond(t2) == 0.0
391388
end
392389
for T in eltypes, t in (rand(T, W, W), rand(T, W, W)')
393-
add!(t, t')
390+
project_hermitian!(t)
394391
vals = @constinferred LinearAlgebra.eigvals(t)
395392
λmax = maximum(s -> maximum(abs, s), values(vals))
396393
λmin = minimum(s -> minimum(abs, s), values(vals))

0 commit comments

Comments
 (0)