@@ -2,12 +2,37 @@ using LinearAlgebra: mul!
22using Test: @test , @testset , @test_broken
33
44using BlockSparseArrays: BlockSparseArray
5- using FusionTensors: FusionTensor, domain_axes, codomain_axes
5+ using FusionTensors:
6+ FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
67using GradedArrays: U1, dual, gradedrange
7- using TensorAlgebra: contract, tuplemortar
8+ using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
89
910include (" setup.jl" )
1011
12+ @testset " matricize" begin
13+ g1 = gradedrange ([U1 (0 ) => 1 , U1 (1 ) => 2 , U1 (2 ) => 3 ])
14+ g2 = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 , U1 (3 ) => 1 ])
15+ g3 = gradedrange ([U1 (- 1 ) => 1 , U1 (0 ) => 2 , U1 (1 ) => 1 ])
16+ g4 = gradedrange ([U1 (- 1 ) => 1 , U1 (0 ) => 1 , U1 (1 ) => 1 ])
17+
18+ ft1 = randn (FusionTensorAxes ((g1, g2), (dual (g3), dual (g4))))
19+ m = matricize (ft1, (1 , 2 ), (3 , 4 ))
20+ @test m isa FusionMatrix
21+ ft2 = unmatricize (m, axes (ft1))
22+ @test ft1 ≈ ft2
23+
24+ biperm = permmortar (((3 ,), (1 , 2 , 4 )))
25+ m2 = matricize (ft1, biperm)
26+ ft_dest = FusionTensor {eltype(ft1)} (undef, axes (ft1)[biperm])
27+ unmatricize! (ft_dest, m2, permmortar (((1 ,), (2 , 3 , 4 ))))
28+ @test ft_dest ≈ permutedims (ft1, biperm)
29+ @test ft_dest ≈ permutedims (ft1, biperm)
30+
31+ ft2 = similar (ft1)
32+ unmatricize! (ft2, m2, biperm)
33+ @test ft1 ≈ ft2
34+ end
35+
1136@testset " contraction" begin
1237 g1 = gradedrange ([U1 (0 ) => 1 , U1 (1 ) => 2 , U1 (2 ) => 3 ])
1338 g2 = gradedrange ([U1 (0 ) => 2 , U1 (1 ) => 2 , U1 (3 ) => 1 ])
@@ -26,15 +51,12 @@ include("setup.jl")
2651 @test codomain_axes (ft3) === codomain_axes (ft1)
2752
2853 # test LinearAlgebra.mul! with in-place matrix product
29- mul! (ft3, ft1, ft2)
30- @test isnothing (check_sanity (ft3))
31- @test domain_axes (ft3) === domain_axes (ft2)
32- @test codomain_axes (ft3) === codomain_axes (ft1)
54+ m1 = randn (FusionTensorAxes ((g1,), (g2,)))
55+ m2 = randn (FusionTensorAxes ((dual (g2),), (g3,)))
56+ m3 = FusionTensor {Float64} (undef, (g1,), (g3,))
3357
34- mul! (ft3, ft1, ft2, 1.0 , 1.0 )
35- @test isnothing (check_sanity (ft2))
36- @test domain_axes (ft3) === domain_axes (ft2)
37- @test codomain_axes (ft3) === codomain_axes (ft1)
58+ mul! (m3, m1, m2, 2.0 , 0.0 )
59+ @test m3 ≈ 2 m1 * m2
3860end
3961
4062@testset " TensorAlgebra interface" begin
4365 g3 = gradedrange ([U1 (- 1 ) => 1 , U1 (0 ) => 2 , U1 (1 ) => 1 ])
4466 g4 = gradedrange ([U1 (- 1 ) => 1 , U1 (0 ) => 1 , U1 (1 ) => 1 ])
4567
46- ft1 = FusionTensor {Float64} (undef, ( g1, g2), (g3, g4))
47- ft2 = FusionTensor {Float64} (undef, dual .((g3, g4)), (dual (g1),))
48- ft3 = FusionTensor {Float64} (undef, dual .((g3, g4)), dual .((g1, g2)))
68+ ft1 = randn ( FusionTensorAxes (( g1, g2), (g3, g4) ))
69+ ft2 = randn ( FusionTensorAxes ( dual .((g3, g4)), (dual (g1),) ))
70+ ft3 = randn ( FusionTensorAxes ( dual .((g3, g4)), dual .((g1, g2) )))
4971
5072 ft4, legs = contract (ft1, (1 , 2 , 3 , 4 ), ft2, (3 , 4 , 5 ))
5173 @test legs == tuplemortar (((1 , 2 ), (5 ,)))
5274 @test isnothing (check_sanity (ft4))
5375 @test domain_axes (ft4) === domain_axes (ft2)
5476 @test codomain_axes (ft4) === codomain_axes (ft1)
77+ @test ft4 ≈ ft1 * ft2
5578
5679 ft5 = contract ((1 , 2 , 5 ), ft1, (1 , 2 , 3 , 4 ), ft2, (3 , 4 , 5 ))
5780 @test isnothing (check_sanity (ft5))
58- @test ft4 ≈ ft5
81+ @test ndims_codomain (ft5) === 3
82+ @test ndims_domain (ft5) === 0
83+ @test permutedims (ft5, (1 , 2 ), (3 ,)) ≈ ft4
5984
6085 ft6 = contract (tuplemortar (((1 , 2 ), (5 ,))), ft1, (1 , 2 , 3 , 4 ), ft2, (3 , 4 , 5 ))
6186 @test isnothing (check_sanity (ft6))
6691 ft7, legs = contract (ft1, (1 , 2 , 3 , 4 ), ft3, (3 , 4 , 1 , 2 ))
6792 @test legs == tuplemortar (((), ()))
6893 @test ft7 isa FusionTensor{Float64,0 }
94+
95+ # include permutations
96+ ft6 = contract (tuplemortar (((5 , 1 ), (2 ,))), ft1, (1 , 2 , 3 , 4 ), ft2, (3 , 4 , 5 ))
97+ @test isnothing (check_sanity (ft6))
98+ @test permutedims (ft6, (2 , 3 ), (1 ,)) ≈ ft4
99+
100+ ft8 = contract (
101+ tuplemortar (((- 3 ,), (- 1 , - 2 , - 4 ))), ft1, (- 1 , 1 , - 2 , 2 ), ft3, (- 3 , 2 , - 4 , 1 )
102+ )
103+ left = permutedims (ft1, (1 , 3 ), (2 , 4 ))
104+ right = permutedims (ft3, (4 , 2 ), (1 , 3 ))
105+ lrprod = left * right
106+ newft = permutedims (lrprod, (3 ,), (1 , 2 , 4 ))
107+ @test newft ≈ ft8
69108end
0 commit comments