Skip to content

Commit 29e58ad

Browse files
committed
Updates for Diagonal and rectangular array tests
1 parent 7e02724 commit 29e58ad

File tree

4 files changed

+60
-19
lines changed

4 files changed

+60
-19
lines changed

src/tensors/factorizations/diagonal.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,34 @@ for f in (:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full,
88
@eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d)
99
end
1010

11-
for f! in (:qr_full!, :qr_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!)
11+
for f! in (:eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!)
1212
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
1313
::DiagonalAlgorithm)
1414
return d, similar(d)
1515
end
1616
end
17+
18+
for f! in (:qr_full!, :qr_compact!)
19+
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
20+
::DiagonalAlgorithm)
21+
return d, similar(d)
22+
end
23+
# to avoid ambiguities
24+
@eval function initialize_output(::typeof($f!), d::AdjointTensorMap,
25+
::DiagonalAlgorithm)
26+
return d, similar(d)
27+
end
28+
end
1729
for f! in (:lq_full!, :lq_compact!)
1830
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
1931
::DiagonalAlgorithm)
2032
return similar(d), d
2133
end
34+
# to avoid ambiguities
35+
@eval function initialize_output(::typeof($f!), d::AdjointTensorMap,
36+
::DiagonalAlgorithm)
37+
return similar(d), d
38+
end
2239
end
2340

2441
for f! in

src/tensors/factorizations/factorizations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
6161
#------------------------------#
6262
# LinearAlgebra overloads
6363
#------------------------------#
64-
LinearAlgebra.svdvals(t::AbstractTensorMap) = diagview(svd_vals(t))
64+
#LinearAlgebra.svdvals(t::AbstractTensorMap) = diagview(svd_vals(t))
6565
LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t))
66-
LinearAlgebra.eigvals(t::AbstractTensorMap) = diagview(eigvals(t))
67-
LinearAlgebra.eigvals!(t::AbstractTensorMap) = diagview(eigvals!(t))
66+
#LinearAlgebra.eigvals(t::AbstractTensorMap) = diagview(eig_vals(t))
67+
LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = diagview(eig_vals!(t))
6868

6969
#--------------------------------------------------#
7070
# Checks for hermiticity and positive definiteness #

src/tensors/factorizations/matrixalgebrakit.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ for f! in (:qr_compact!, :qr_full!,
4545
end
4646

4747
# Handle these separately because single output instead of tuple
48-
for f! in (:qr_null!, :lq_null!, :svd_vals!, :eig_vals!, :eigh_vals!)
48+
for f! in (:qr_null!, :lq_null!)
4949
@eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm)
5050
check_input($f!, t, N, alg)
5151

@@ -60,6 +60,22 @@ for f! in (:qr_null!, :lq_null!, :svd_vals!, :eig_vals!, :eigh_vals!)
6060
end
6161
end
6262

63+
# Handle these separately because single output instead of tuple
64+
for f! in (:svd_vals!, :eig_vals!, :eigh_vals!)
65+
@eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm)
66+
check_input($f!, t, N, alg)
67+
68+
foreachblock(t, N) do _, (b, n)
69+
n′ = $f!(b, n.diag, alg)
70+
# deal with the case where the output is not the same as the input
71+
n.diag === n′ || copyto!(n, diagview(n′))
72+
return nothing
73+
end
74+
75+
return N
76+
end
77+
end
78+
6379
# Singular value decomposition
6480
# ----------------------------
6581
const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap}
@@ -101,11 +117,19 @@ end
101117
function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict,
102118
::AbstractAlgorithm)
103119
@check_scalar S t real
104-
V_cod = infimum(fuse(codomain(t)), fuse(domain(t)))
120+
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
105121
@check_space(S, V_cod V_dom)
106122
return nothing
107123
end
108124

125+
function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D::DiagonalTensorMap,
126+
::AbstractAlgorithm)
127+
@check_scalar D t real
128+
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
129+
@check_space(D, V_cod V_dom)
130+
return nothing
131+
end
132+
109133
function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm)
110134
V_cod = fuse(codomain(t))
111135
V_dom = fuse(domain(t))

test/factorizations.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ for V in spacelist
2727
W = V1 V2
2828
@testset for T in (Float32, ComplexF64)
2929
# Test both a normal tensor and an adjoint one.
30-
ts = (rand(T, W, W'), rand(T, W, W')')
30+
ts = (rand(T, W, W'), rand(T, W, W')', rand(T, V1, W'), rand(T, V1, W')')
3131
@testset for t in ts
32-
# test squares and rectangles here
3332
@testset "leftorth with $alg" for alg in
3433
(TensorKit.LAPACK_HouseholderQR(),
3534
TensorKit.LAPACK_HouseholderQR(;
@@ -40,10 +39,10 @@ for V in spacelist
4039
TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()),
4140
TensorKit.LAPACK_QRIteration(),
4241
TensorKit.LAPACK_DivideAndConquer())
42+
(codomain(t) domain(t)) && alg isa TensorKit.PolarViaSVD && continue
4343
Q, R = @constinferred leftorth(t; alg=alg)
4444
@test isisometry(Q)
45-
tQR = Q * R
46-
@test tQR t
45+
@test Q * R t
4746
end
4847
@testset "leftnull with $alg" for alg in
4948
(TensorKit.LAPACK_HouseholderQR(),
@@ -61,6 +60,7 @@ for V in spacelist
6160
TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()),
6261
TensorKit.LAPACK_QRIteration(),
6362
TensorKit.LAPACK_DivideAndConquer())
63+
(domain(t) codomain(t)) && alg isa TensorKit.PolarViaSVD && continue
6464
L, Q = @constinferred rightorth(t; alg=alg)
6565
@test isisometry(Q; side=:right)
6666
@test L * Q t
@@ -80,28 +80,28 @@ for V in spacelist
8080
@test isisometry(V; side=:right)
8181
@test U * S * V t
8282

83-
s = LinearAlgebra.svdvals(t)
83+
s = LinearAlgebra.svdvals(t)
8484
s′ = LinearAlgebra.diag(S)
8585
for (c, b) in s
8686
@test b s′[c]
8787
end
88-
s = LinearAlgebra.svdvals(t')
88+
s = LinearAlgebra.svdvals(t')
8989
s′ = LinearAlgebra.diag(S')
9090
for (c, b) in s
9191
@test b s′[c]
9292
end
9393
end
9494
@testset "cond and rank" begin
95-
d1 = dim(codomain(t))
96-
d2 = dim(domain(t))
95+
d1 = dim(codomain(t))
96+
d2 = dim(domain(t))
9797
@test rank(t) == min(d1, d2)
98-
M = leftnull(t)
99-
@test rank(M) == max(d1, d2) - min(d1, d2)
100-
t3 = unitary(T, V1 V2, V1 V2)
98+
M = leftnull(t)
99+
@test rank(M) + rank(t) == d1
100+
t3 = unitary(T, V1 V2, V1 V2)
101101
@test cond(t3) one(real(T))
102102
@test rank(t3) == dim(V1 V2)
103-
t4 = randn(T, V1 V2, V1 V2)
104-
t4 = (t4 + t4') / 2
103+
t4 = randn(T, V1 V2, V1 V2)
104+
t4 = (t4 + t4') / 2
105105
vals = LinearAlgebra.eigvals(t4)
106106
λmax = maximum(s -> maximum(abs, s), values(vals))
107107
λmin = minimum(s -> minimum(abs, s), values(vals))

0 commit comments

Comments
 (0)