Skip to content

Commit 6ed0960

Browse files
committed
add diagonal tensor tests
1 parent 3a1cf23 commit 6ed0960

File tree

1 file changed

+232
-0
lines changed

1 file changed

+232
-0
lines changed

test/test_A4.jl

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using MultiTensorKit
22
using TensorKitSectors, TensorKit
33
using Test, TestExtras
44
using Random
5+
using LinearAlgebra: LinearAlgebra
56

67
const MTK = MultiTensorKit
78
const TK = TensorKit
@@ -600,6 +601,237 @@ println("---------------------------------------")
600601
end
601602
end
602603

604+
println("-------------------------------------------")
605+
println("| Multifusion diagonal tensor tests |")
606+
println("-------------------------------------------")
607+
608+
V = Vect[I](values(I)[k] => 1 for k in 1:length(values(I)))
609+
610+
@timedtestset "DiagonalTensor" begin
611+
@timedtestset "Basic properties and algebra" begin
612+
for T in (Float32, Float64, ComplexF32, ComplexF64, BigFloat)
613+
# constructors
614+
t = @constinferred DiagonalTensorMap{T}(undef, V)
615+
t = @constinferred DiagonalTensorMap(rand(T, reduceddim(V)), V)
616+
t2 = @constinferred DiagonalTensorMap{T}(undef, space(t))
617+
@test space(t2) == space(t)
618+
@test_throws ArgumentError DiagonalTensorMap{T}(undef, V^2 V)
619+
t2 = @constinferred DiagonalTensorMap{T}(undef, domain(t))
620+
@test space(t2) == space(t)
621+
@test_throws ArgumentError DiagonalTensorMap{T}(undef, V^2)
622+
# properties
623+
@test @constinferred(hash(t)) == hash(deepcopy(t))
624+
@test scalartype(t) == T
625+
@test codomain(t) == ProductSpace(V)
626+
@test domain(t) == ProductSpace(V)
627+
@test space(t) == (V V)
628+
@test space(t') == (V V)
629+
@test dim(t) == dim(space(t))
630+
# blocks
631+
bs = @constinferred blocks(t)
632+
(c, b1), state = @constinferred Nothing iterate(bs)
633+
@test c == first(blocksectors(V V))
634+
next = @constinferred Nothing iterate(bs, state)
635+
b2 = @constinferred block(t, first(blocksectors(t)))
636+
@test b1 == b2
637+
@test eltype(bs) === Pair{typeof(c),typeof(b1)}
638+
@test typeof(b1) === TK.blocktype(t)
639+
# basic linear algebra
640+
@test isa(@constinferred(norm(t)), real(T))
641+
@test norm(t)^2 dot(t, t)
642+
α = rand(T)
643+
@test norm* t) abs(α) * norm(t)
644+
@test norm(t + t, 2) 2 * norm(t, 2)
645+
@test norm(t + t, 1) 2 * norm(t, 1)
646+
@test norm(t + t, Inf) 2 * norm(t, Inf)
647+
p = 3 * rand(Float64)
648+
@test norm(t + t, p) 2 * norm(t, p)
649+
@test norm(t) norm(t')
650+
651+
@test t == @constinferred(TensorMap(t))
652+
@test norm(t + TensorMap(t)) 2 * norm(t)
653+
654+
@test norm(zerovector!(t)) == 0
655+
@test norm(one!(t)) sqrt(dim(V))
656+
@test one!(t) == id(V)
657+
@test norm(one!(t) - id(V)) == 0
658+
659+
t1 = DiagonalTensorMap(rand(T, reduceddim(V)), V)
660+
t2 = DiagonalTensorMap(rand(T, reduceddim(V)), V)
661+
t3 = DiagonalTensorMap(rand(T, reduceddim(V)), V)
662+
α = rand(T)
663+
β = rand(T)
664+
@test @constinferred(dot(t1, t2)) conj(dot(t2, t1))
665+
@test dot(t2, t1) conj(dot(t2', t1'))
666+
@test dot(t3, α * t1 + β * t2) α * dot(t3, t1) + β * dot(t3, t2)
667+
end
668+
end
669+
670+
@timedtestset "Basic linear algebra: test via conversion" begin
671+
for T in (Float32, ComplexF64)
672+
t1 = DiagonalTensorMap(rand(T, reduceddim(V)), V)
673+
t2 = DiagonalTensorMap(rand(T, reduceddim(V)), V)
674+
@test norm(t1, 2) norm(convert(TensorMap, t1), 2)
675+
@test dot(t2, t1) dot(convert(TensorMap, t2), convert(TensorMap, t1))
676+
α = rand(T)
677+
@test convert(TensorMap, α * t1) α * convert(TensorMap, t1)
678+
@test convert(TensorMap, t1') convert(TensorMap, t1)'
679+
@test convert(TensorMap, t1 + t2)
680+
convert(TensorMap, t1) + convert(TensorMap, t2)
681+
end
682+
end
683+
@timedtestset "Real and imaginary parts" begin
684+
for T in (Float64, ComplexF64, ComplexF32)
685+
t = DiagonalTensorMap(rand(T, reduceddim(V)), V)
686+
687+
tr = @constinferred real(t)
688+
@test scalartype(tr) <: Real
689+
@test real(convert(TensorMap, t)) == convert(TensorMap, tr)
690+
691+
ti = @constinferred imag(t)
692+
@test scalartype(ti) <: Real
693+
@test imag(convert(TensorMap, t)) == convert(TensorMap, ti)
694+
695+
tc = @inferred complex(t)
696+
@test scalartype(tc) <: Complex
697+
@test complex(convert(TensorMap, t)) == convert(TensorMap, tc)
698+
699+
tc2 = @inferred complex(tr, ti)
700+
@test tc2 tc
701+
end
702+
end
703+
@timedtestset "Tensor conversion" begin
704+
t = @constinferred DiagonalTensorMap(undef, V)
705+
rand!(t.data)
706+
# element type conversion
707+
tc = complex(t)
708+
@test convert(typeof(tc), t) == tc
709+
@test typeof(convert(typeof(tc), t)) == typeof(tc)
710+
# to and from generic TensorMap
711+
td = DiagonalTensorMap(TensorMap(t))
712+
@test t == td
713+
@test typeof(td) == typeof(t)
714+
end
715+
@timedtestset "Trace, Multiplication and inverse" begin
716+
t1 = DiagonalTensorMap(rand(Float64, reduceddim(V)), V)
717+
t2 = DiagonalTensorMap(rand(ComplexF64, reduceddim(V)), V)
718+
@test tr(TensorMap(t1)) == @constinferred tr(t1)
719+
@test tr(TensorMap(t2)) == @constinferred tr(t2)
720+
@test TensorMap(@constinferred t1 * t2) TensorMap(t1) * TensorMap(t2)
721+
@test TensorMap(@constinferred t1 \ t2) TensorMap(t1) \ TensorMap(t2)
722+
@test TensorMap(@constinferred t1 / t2) TensorMap(t1) / TensorMap(t2)
723+
@test TensorMap(@constinferred inv(t1)) inv(TensorMap(t1))
724+
@test TensorMap(@constinferred pinv(t1)) pinv(TensorMap(t1))
725+
@test all(Base.Fix2(isa, DiagonalTensorMap),
726+
(t1 * t2, t1 \ t2, t1 / t2, inv(t1), pinv(t1)))
727+
# no V * V' * V ← V or V^2 ← V tests due to Nsymbol erroring where fusion is forbidden
728+
end
729+
@timedtestset "Tensor contraction " for i in 1:r
730+
W = Vect[I]((i, i, label) => 2 for label in 1:MTK._numlabels(I, i, i))
731+
732+
d = DiagonalTensorMap(rand(ComplexF64, reduceddim(W)), W)
733+
t = TensorMap(d)
734+
A = randn(ComplexF64, W W' W, W)
735+
B = randn(ComplexF64, W W' W, W W') # empty for modules so untested
736+
737+
@planar E1[-1 -2 -3; -4 -5] := B[-1 -2 -3; 1 -5] * d[1; -4]
738+
@planar E2[-1 -2 -3; -4 -5] := B[-1 -2 -3; 1 -5] * t[1; -4]
739+
@test E1 E2
740+
@planar E1[-1 -2 -3; -4 -5] = B[-1 -2 -3; -4 1] * d'[-5; 1]
741+
@planar E2[-1 -2 -3; -4 -5] = B[-1 -2 -3; -4 1] * t'[-5; 1]
742+
@test E1 E2
743+
@planar E1[-1 -2 -3; -4 -5] = B[1 -2 -3; -4 -5] * d[-1; 1]
744+
@planar E2[-1 -2 -3; -4 -5] = B[1 -2 -3; -4 -5] * t[-1; 1]
745+
@test E1 E2
746+
@planar E1[-1 -2 -3; -4 -5] = B[-1 1 -3; -4 -5] * d[1; -2]
747+
@planar E2[-1 -2 -3; -4 -5] = B[-1 1 -3; -4 -5] * t[1; -2]
748+
@test E1 E2
749+
@planar E1[-1 -2 -3; -4 -5] = B[-1 -2 1; -4 -5] * d'[-3; 1]
750+
@planar E2[-1 -2 -3; -4 -5] = B[-1 -2 1; -4 -5] * t'[-3; 1]
751+
@test E1 E2
752+
end
753+
@timedtestset "Factorization" begin
754+
for T in (Float32, ComplexF64)
755+
t = DiagonalTensorMap(rand(T, reduceddim(V)), V)
756+
@testset "eig" begin
757+
D, W = @constinferred eig(t)
758+
@test t * W W * D
759+
t2 = t + t'
760+
D2, V2 = @constinferred eigh(t2)
761+
VdV2 = V2' * V2
762+
@test VdV2 one(VdV2)
763+
@test t2 * V2 V2 * D2
764+
765+
@test rank(D) rank(t)
766+
@test cond(D) cond(t)
767+
@test all(((s, t),) -> isapprox(s, t),
768+
zip(values(LinearAlgebra.eigvals(D)),
769+
values(LinearAlgebra.eigvals(t))))
770+
end
771+
@testset "leftorth with $alg" for alg in (TK.QR(), TK.QL())
772+
Q, R = @constinferred leftorth(t; alg=alg)
773+
QdQ = Q' * Q
774+
@test QdQ one(QdQ)
775+
@test Q * R t
776+
if alg isa Polar
777+
@test isposdef(R)
778+
end
779+
end
780+
@testset "rightorth with $alg" for alg in (TK.RQ(), TK.LQ())
781+
L, Q = @constinferred rightorth(t; alg=alg)
782+
QQd = Q * Q'
783+
@test QQd one(QQd)
784+
@test L * Q t
785+
if alg isa Polar
786+
@test isposdef(L)
787+
end
788+
end
789+
@testset "tsvd with $alg" for alg in (TK.SVD(), TK.SDD())
790+
U, S, Vᴴ = @constinferred tsvd(t; alg=alg)
791+
UdU = U' * U
792+
@test UdU one(UdU)
793+
VdV = Vᴴ * Vᴴ'
794+
@test VdV one(VdV)
795+
@test U * S * Vᴴ t
796+
797+
@test rank(S) rank(t)
798+
@test cond(S) cond(t)
799+
@test all(((s, t),) -> isapprox(s, t),
800+
zip(values(LinearAlgebra.svdvals(S)),
801+
values(LinearAlgebra.svdvals(t))))
802+
end
803+
end
804+
end
805+
@timedtestset "Tensor functions" begin
806+
for T in (Float64, ComplexF64)
807+
d = DiagonalTensorMap(rand(T, reduceddim(V)), V)
808+
# rand is important for positive numbers in the real case, for log and sqrt
809+
t = TensorMap(d)
810+
@test @constinferred exp(d) exp(t)
811+
@test @constinferred log(d) log(t)
812+
@test @constinferred sqrt(d) sqrt(t)
813+
@test @constinferred sin(d) sin(t)
814+
@test @constinferred cos(d) cos(t)
815+
@test @constinferred tan(d) tan(t)
816+
@test @constinferred cot(d) cot(t)
817+
@test @constinferred sinh(d) sinh(t)
818+
@test @constinferred cosh(d) cosh(t)
819+
@test @constinferred tanh(d) tanh(t)
820+
@test @constinferred coth(d) coth(t)
821+
@test @constinferred asin(d) asin(t)
822+
@test @constinferred acos(d) acos(t)
823+
@test @constinferred atan(d) atan(t)
824+
@test @constinferred acot(d) acot(t)
825+
@test @constinferred asinh(d) asinh(t)
826+
@test @constinferred acosh(one(d) + d) acosh(one(t) + t)
827+
@test @constinferred atanh(d) atanh(t)
828+
@test @constinferred acoth(one(t) + d) acoth(one(d) + t)
829+
end
830+
end
831+
end
832+
833+
834+
603835
@testset "$Istr ($i, $j) left and right units" for i in 1:r, j in 1:r
604836
Cij_obs = I.(i, j, MTK._get_dual_cache(I)[2][i, j])
605837

0 commit comments

Comments
 (0)