Skip to content

Commit fb340d4

Browse files
authored
define matricize (#77)
1 parent da61225 commit fb340d4

File tree

9 files changed

+151
-60
lines changed

9 files changed

+151
-60
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.7"
4+
version = "0.5.8"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -28,7 +28,7 @@ LRUCache = "1.6"
2828
LinearAlgebra = "1.10"
2929
Random = "1.10"
3030
Strided = "2.3"
31-
TensorAlgebra = "0.3.8"
31+
TensorAlgebra = "0.3.12"
3232
TensorProducts = "0.1.7"
3333
TypeParameterAccessors = "0.4"
3434
WignerSymbols = "2.0.0"

src/fusiontensor/base_interface.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,13 @@ function Base.similar(::FusionTensor, ::Type{T}, new_axes::FusionTensorAxes) whe
122122
return FusionTensor{T}(undef, new_axes)
123123
end
124124

125-
Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor")
125+
function Base.show(io::IO, ft::FusionTensor)
126+
return print(io, "$(ndims(ft))-dim FusionTensor with size $(size(ft))")
127+
end
126128

127129
function Base.show(io::IO, ::MIME"text/plain", ft::FusionTensor)
128-
print(io, "$(ndims(ft))-dim FusionTensor with axes:")
130+
print(io, ft)
131+
print(" and axes:")
129132
for ax in axes(ft)
130133
print(io, "\n", ax)
131134
end

src/fusiontensor/fusiontensor.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using GradedArrays:
2222
space_isequal
2323
using LinearAlgebra: UniformScaling
2424
using Random: Random, AbstractRNG, randn!
25-
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
25+
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar, length_codomain, length_domain
2626
using TensorProducts: tensor_product
2727
using TypeParameterAccessors: type_parameters
2828

@@ -134,6 +134,10 @@ struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} <
134134
end
135135
end
136136

137+
const FusionMatrix{T,Axes,Mat,Mapping} = FusionTensor{
138+
T,2,Axes,Mapping
139+
} where {BT<:BlockedTuple{2,(1, 1)},Axes<:FusionTensorAxes{BT}}
140+
137141
# ===================================== Accessors ========================================
138142

139143
data_matrix(ft::FusionTensor) = ft.data_matrix

src/fusiontensor/fusiontensoraxes.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@ using GradedArrays:
88
dual,
99
sector_type,
1010
trivial
11-
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTuple
11+
using TensorAlgebra:
12+
TensorAlgebra,
13+
AbstractBlockPermutation,
14+
AbstractBlockTuple,
15+
BlockedTuple,
16+
length_codomain,
17+
length_domain
1218
using TensorProducts:
1319
using TypeParameterAccessors: type_parameters
1420

@@ -65,6 +71,8 @@ TensorAlgebra.BlockedTuple(fta::FusionTensorAxes) = fta.outer_axes
6571

6672
TensorAlgebra.trivial_axis(fta::FusionTensorAxes) = trivial_axis(sector_type(fta))
6773

74+
TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta))
75+
6876
# ================================== Base interface ======================================
6977

7078
for f in [
@@ -140,7 +148,3 @@ function fused_domain(fta::FusionTensorAxes)
140148
end
141149
return dual((dual.(domain(fta))...))
142150
end
143-
144-
length_codomain(fta::FusionTensorAxes) = length(codomain(fta))
145-
146-
length_domain(fta::FusionTensorAxes) = length(domain(fta))

src/fusiontensor/linear_algebra_interface.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors
1010
# allow to contract with different eltype and let BlockSparseArray ensure compatibility
1111
# impose matching type and number of axes at compile time
1212
# impose matching axes at run time
13-
# TODO remove this once TensorAlgebra.contract can be used?
1413
function LinearAlgebra.mul!(
15-
C::FusionTensor, A::FusionTensor, B::FusionTensor, α::Number, β::Number
14+
C::FusionMatrix, A::FusionMatrix, B::FusionMatrix, α::Number, β::Number
1615
)
1716

1817
# compile time checks

src/fusiontensor/tensor_algebra_interface.jl

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,70 @@ using LinearAlgebra: mul!
44

55
using BlockArrays: Block
66

7-
using TensorAlgebra: BlockedPermutation, Matricize, TensorAlgebra
7+
using GradedArrays: space_isequal
8+
using TensorAlgebra:
9+
TensorAlgebra,
10+
AbstractBlockPermutation,
11+
BlockedTrivialPermutation,
12+
BlockedTuple,
13+
FusionStyle,
14+
Matricize,
15+
blockedperm,
16+
genperm,
17+
unmatricize
818

9-
# TODO how to deal with inner contraction = no ouput axis?
10-
# => currently biperm_dest is a BlockedPermutation{0}, change this
11-
function TensorAlgebra.allocate_output(
19+
function TensorAlgebra.output_axes(
1220
::typeof(contract),
13-
biperm_dest::BlockedPermutation{2},
21+
biperm_dest::AbstractBlockPermutation{2},
1422
a1::FusionTensor,
15-
biperm1::BlockedPermutation{2},
23+
biperm1::AbstractBlockPermutation{2},
1624
a2::FusionTensor,
17-
biperm2::BlockedPermutation{2},
18-
α::Number=true,
25+
biperm2::AbstractBlockPermutation{2},
26+
α::Number=one(Bool),
1927
)
20-
axes_dest = (
21-
map(i -> axes(a1)[i], first(blocks(biperm1))),
22-
map(i -> axes(a2)[i], last(blocks(biperm2))),
28+
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
29+
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
30+
@assert all(space_isequal.(dual.(axes_contracted), axes_contracted2))
31+
flat_axes = genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest))
32+
return FusionTensorAxes(
33+
tuplemortar((
34+
flat_axes[begin:length_codomain(biperm_dest)],
35+
flat_axes[(length_codomain(biperm_dest) + 1):end],
36+
)),
2337
)
24-
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
2538
end
2639

27-
# TBD do really I need to define these as I cannot use them in contract! and has to redefine it?
28-
#TensorAlgebra.fusedims(ft::FusionTensor, perm::BlockedPermutation{2}) = permutedims(ft, perm)
29-
#function TensorAlgebra.splitdims(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
30-
#function TensorAlgebra.splitdims!(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
31-
32-
# I cannot use contract! from TensorAlgebra/src/contract/contract_matricize/contract.jl
33-
# as it calls _mul!, which I should not overload.
34-
# TBD define fallback _mul!(::AbstractArray, ::AbstractArray, ::AbstractArray) in TensorAlgebra?
35-
function TensorAlgebra.contract!(
36-
::Matricize,
37-
a_dest::FusionTensor,
38-
::BlockedPermutation{2},
39-
a1::FusionTensor,
40-
biperm1::BlockedPermutation{2},
41-
a2::FusionTensor,
42-
biperm2::BlockedPermutation{2},
43-
α::Number,
44-
β::Number,
40+
struct FusionTensorFusionStyle <: FusionStyle end
41+
42+
TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle()
43+
44+
function TensorAlgebra.matricize(
45+
::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2}
46+
)
47+
blocklengths(biperm) == blocklengths(axes(ft)) ||
48+
throw(ArgumentError("Invalid trivial biperm"))
49+
return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),))
50+
end
51+
52+
function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes)
53+
return FusionTensor(data_matrix(m), blocked_axes)
54+
end
55+
56+
function TensorAlgebra.permuteblockeddims(
57+
ft::FusionTensor, biperm::AbstractBlockPermutation
4558
)
46-
a1_perm = permutedims(a1, biperm1)
47-
a2_perm = permutedims(a2, biperm2)
48-
mul!(a_dest, a1_perm, a2_perm, α, β)
59+
return permutedims(ft, biperm)
60+
end
61+
62+
function TensorAlgebra.permuteblockeddims!(
63+
a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation
64+
)
65+
return permutedims!(a, b, biperm)
66+
end
67+
68+
# TODO define custom broadcast rules
69+
function TensorAlgebra.unmatricize_add!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β)
70+
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
71+
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
4972
return a_dest
5073
end

test/test_basics.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Test: @test, @test_throws, @testset
33
using BlockArrays: Block
44
using BlockSparseArrays: BlockSparseArray, eachblockstoredindex
55
using FusionTensors:
6+
FusionMatrix,
67
FusionTensor,
78
FusionTensorAxes,
89
codomain_axes,
@@ -40,6 +41,7 @@ include("setup.jl")
4041
fta = FusionTensorAxes((g1,), (g2,))
4142
ft0 = FusionTensor{Float64}(undef, fta)
4243
@test ft0 isa FusionTensor
44+
@test ft0 isa FusionMatrix
4345
@test space_isequal(codomain_axis(ft0), g1)
4446
@test space_isequal(domain_axis(ft0), g2)
4547

@@ -134,6 +136,8 @@ end
134136
m2 = BlockSparseArray{Float64}(undef, gr, gc)
135137
ft = FusionTensor(m2, (g1, g2), (g3, g4))
136138

139+
@test ft isa FusionTensor
140+
@test !(ft isa FusionMatrix)
137141
@test data_matrix(ft) == m2
138142
@test checkspaces(codomain_axes(ft), (g1, g2))
139143
@test checkspaces(domain_axes(ft), (g3, g4))
@@ -155,6 +159,8 @@ end
155159

156160
# one row axis
157161
ft1 = FusionTensor{Float64}(undef, (g1,), ())
162+
@test ft1 isa FusionTensor
163+
@test !(ft1 isa FusionMatrix)
158164
@test ndims_codomain(ft1) == 1
159165
@test ndims_domain(ft1) == 0
160166
@test ndims(ft1) == 1
@@ -165,6 +171,8 @@ end
165171

166172
# one column axis
167173
ft2 = FusionTensor{Float64}(undef, (), (g1,))
174+
@test ft2 isa FusionTensor
175+
@test !(ft2 isa FusionMatrix)
168176
@test ndims_codomain(ft2) == 0
169177
@test ndims_domain(ft2) == 1
170178
@test ndims(ft2) == 1
@@ -175,13 +183,22 @@ end
175183

176184
# zero axis
177185
ft3 = FusionTensor{Float64}(undef, (), ())
186+
@test ft3 isa FusionTensor
187+
@test !(ft3 isa FusionMatrix)
178188
@test ndims_codomain(ft3) == 0
179189
@test ndims_domain(ft3) == 0
180190
@test ndims(ft3) == 0
181191
@test size(ft3) == tuplemortar(((), ()))
182192
@test size(data_matrix(ft3)) == (1, 1)
183193
@test isnothing(check_sanity(ft3))
184194
@test sector_type(ft3) === TrivialSector
195+
196+
ft4 = FusionTensor{Float64}(undef, (g1, g1), ())
197+
@test ft4 isa FusionTensor
198+
@test !(ft4 isa FusionMatrix)
199+
ft5 = FusionTensor{Float64}(undef, (), (g1, g1))
200+
@test ft5 isa FusionTensor
201+
@test !(ft5 isa FusionMatrix)
185202
end
186203

187204
@testset "specific constructors" begin

test/test_contraction.jl

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,37 @@ using LinearAlgebra: mul!
22
using Test: @test, @testset, @test_broken
33

44
using BlockSparseArrays: BlockSparseArray
5-
using FusionTensors: FusionTensor, domain_axes, codomain_axes
5+
using FusionTensors:
6+
FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
67
using GradedArrays: U1, dual, gradedrange
7-
using TensorAlgebra: contract, tuplemortar
8+
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
89

910
include("setup.jl")
1011

12+
@testset "matricize" begin
13+
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
14+
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
15+
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
16+
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
17+
18+
ft1 = randn(FusionTensorAxes((g1, g2), (dual(g3), dual(g4))))
19+
m = matricize(ft1, (1, 2), (3, 4))
20+
@test m isa FusionMatrix
21+
ft2 = unmatricize(m, axes(ft1))
22+
@test ft1 ft2
23+
24+
biperm = permmortar(((3,), (1, 2, 4)))
25+
m2 = matricize(ft1, biperm)
26+
ft_dest = FusionTensor{eltype(ft1)}(undef, axes(ft1)[biperm])
27+
unmatricize!(ft_dest, m2, permmortar(((1,), (2, 3, 4))))
28+
@test ft_dest permutedims(ft1, biperm)
29+
@test ft_dest permutedims(ft1, biperm)
30+
31+
ft2 = similar(ft1)
32+
unmatricize!(ft2, m2, biperm)
33+
@test ft1 ft2
34+
end
35+
1136
@testset "contraction" begin
1237
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
1338
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
@@ -26,15 +51,12 @@ include("setup.jl")
2651
@test codomain_axes(ft3) === codomain_axes(ft1)
2752

2853
# test LinearAlgebra.mul! with in-place matrix product
29-
mul!(ft3, ft1, ft2)
30-
@test isnothing(check_sanity(ft3))
31-
@test domain_axes(ft3) === domain_axes(ft2)
32-
@test codomain_axes(ft3) === codomain_axes(ft1)
54+
m1 = randn(FusionTensorAxes((g1,), (g2,)))
55+
m2 = randn(FusionTensorAxes((dual(g2),), (g3,)))
56+
m3 = FusionTensor{Float64}(undef, (g1,), (g3,))
3357

34-
mul!(ft3, ft1, ft2, 1.0, 1.0)
35-
@test isnothing(check_sanity(ft2))
36-
@test domain_axes(ft3) === domain_axes(ft2)
37-
@test codomain_axes(ft3) === codomain_axes(ft1)
58+
mul!(m3, m1, m2, 2.0, 0.0)
59+
@test m3 2m1 * m2
3860
end
3961

4062
@testset "TensorAlgebra interface" begin
@@ -43,19 +65,22 @@ end
4365
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
4466
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
4567

46-
ft1 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
47-
ft2 = FusionTensor{Float64}(undef, dual.((g3, g4)), (dual(g1),))
48-
ft3 = FusionTensor{Float64}(undef, dual.((g3, g4)), dual.((g1, g2)))
68+
ft1 = randn(FusionTensorAxes((g1, g2), (g3, g4)))
69+
ft2 = randn(FusionTensorAxes(dual.((g3, g4)), (dual(g1),)))
70+
ft3 = randn(FusionTensorAxes(dual.((g3, g4)), dual.((g1, g2))))
4971

5072
ft4, legs = contract(ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
5173
@test legs == tuplemortar(((1, 2), (5,)))
5274
@test isnothing(check_sanity(ft4))
5375
@test domain_axes(ft4) === domain_axes(ft2)
5476
@test codomain_axes(ft4) === codomain_axes(ft1)
77+
@test ft4 ft1 * ft2
5578

5679
ft5 = contract((1, 2, 5), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
5780
@test isnothing(check_sanity(ft5))
58-
@test ft4 ft5
81+
@test ndims_codomain(ft5) === 3
82+
@test ndims_domain(ft5) === 0
83+
@test permutedims(ft5, (1, 2), (3,)) ft4
5984

6085
ft6 = contract(tuplemortar(((1, 2), (5,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
6186
@test isnothing(check_sanity(ft6))
@@ -66,4 +91,18 @@ end
6691
ft7, legs = contract(ft1, (1, 2, 3, 4), ft3, (3, 4, 1, 2))
6792
@test legs == tuplemortar(((), ()))
6893
@test ft7 isa FusionTensor{Float64,0}
94+
95+
# include permutations
96+
ft6 = contract(tuplemortar(((5, 1), (2,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
97+
@test isnothing(check_sanity(ft6))
98+
@test permutedims(ft6, (2, 3), (1,)) ft4
99+
100+
ft8 = contract(
101+
tuplemortar(((-3,), (-1, -2, -4))), ft1, (-1, 1, -2, 2), ft3, (-3, 2, -4, 1)
102+
)
103+
left = permutedims(ft1, (1, 3), (2, 4))
104+
right = permutedims(ft3, (4, 2), (1, 3))
105+
lrprod = left * right
106+
newft = permutedims(lrprod, (3,), (1, 2, 4))
107+
@test newft ft8
69108
end

test/test_fusiontensoraxes.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Test: @test, @test_throws, @testset
22

33
using TensorProducts:
44
using BlockArrays: Block, blockedrange, blocklength, blocklengths, blocks
5-
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
5+
using TensorAlgebra: BlockedTuple, length_codomain, trivial_axis, tuplemortar
66

77
using FusionTensors:
88
FusionTensorAxes,
@@ -75,6 +75,7 @@ end
7575
@test blocklength(fta) == 2
7676
@test blocklengths(fta) == (2, 2)
7777
@test blocks(fta) == blocks(bt)
78+
@test length_codomain(fta) == 2
7879

7980
@test sector_type(fta) === sector_type(g2)
8081
@test length(codomain(fta)) == 2
@@ -107,6 +108,7 @@ end
107108
@test blocklength(fta) == 2
108109
@test blocklengths(fta) == (0, 0)
109110
@test sector_type(fta) == TrivialSector
111+
@test length_codomain(fta) == 0
110112

111113
@test codomain(fta) == ()
112114
@test space_isequal(fused_codomain(fta), trivial_axis(TrivialSector))

0 commit comments

Comments
 (0)