@@ -19,136 +19,32 @@ using GradedArrays:
19
19
sectormergesort,
20
20
sectors,
21
21
space_isequal
22
- using TensorAlgebra: BlockedTuple, tuplemortar
22
+ using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
23
23
using TensorProducts: tensor_product
24
24
using TypeParameterAccessors: type_parameters
25
25
26
- struct FusionTensor{T,N,Axes,Mat<: AbstractMatrix{T} ,Mapping} <: AbstractArray{T,N}
27
- data_matrix:: Mat
28
- axes:: Axes
29
- trees_block_mapping:: Mapping
30
-
31
- # inner constructor to impose constraints on types
32
- function FusionTensor {T,N,Axes,Mat,Mapping} (
33
- mat, legs, trees_block_mapping
34
- ) where {T,N,Axes,Mat,Mapping}
35
- S = isempty (legs) ? TrivialSector : sector_type (first (legs))
36
- @assert keytype (trees_block_mapping) < :
37
- Tuple{<: SectorFusionTree{S} ,<: SectorFusionTree{S} }
38
- @assert all (sector_type .(Tuple (legs)) .=== S)
39
- return new {T,N,Axes,Mat,Mapping} (mat, legs, trees_block_mapping)
40
- end
41
- end
42
-
43
- function FusionTensor (
44
- mat:: AbstractMatrix ,
45
- legs:: BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractGradedUnitRange}}} ,
46
- trees_block_mapping:: Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}} ,
47
- )
48
- return FusionTensor{
49
- eltype (mat),length (legs),typeof (legs),typeof (mat),typeof (trees_block_mapping)
50
- }(
51
- mat, legs, trees_block_mapping
52
- )
53
- end
54
-
55
- # getters
56
- data_matrix (ft:: FusionTensor ) = ft. data_matrix
57
- trees_block_mapping (ft:: FusionTensor ) = ft. trees_block_mapping
58
-
59
- # misc access
60
- codomain_axes (ft:: FusionTensor ) = axes (ft)[Block (1 )]
61
- domain_axes (ft:: FusionTensor ) = axes (ft)[Block (2 )]
62
- ndims_codomain (ft:: FusionTensor ) = length (codomain_axes (ft))
63
- ndims_domain (ft:: FusionTensor ) = length (domain_axes (ft))
64
-
65
- dummy_axis (ft:: FusionTensor ) = dummy_axis (sector_type (ft))
66
- dummy_axis (:: Type{S} ) where {S<: AbstractSector } = gradedrange ([trivial (S) => 1 ])
67
-
68
- function codomain_axis (ft:: FusionTensor )
69
- if ndims_codomain (ft) == 0
70
- return dummy_axis (ft)
71
- end
72
- return ⊗ (codomain_axes (ft)... )
73
- end
74
- function domain_axis (ft:: FusionTensor )
75
- if ndims_domain (ft) == 0
76
- return dummy_axis (ft)
77
- end
78
- return dual (⊗ (dual .(domain_axes (ft))... ))
79
- end
80
- function charge_block_size (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
81
- b = Tuple (findblock (ft, f1, f2))
82
- return ntuple (i -> sector_multiplicity (axes (ft, i)[b[i]]), ndims (ft))
83
- end
26
+ # ======================================= Misc ===========================================
84
27
85
- # GradedArrays interface
86
- function GradedArrays. sector_type (
87
- :: Type {<: FusionTensor{<:Any,<:Any,<:Any,<:Any,<:Dict{<:Tuple{<:Any,F}}} }
88
- ) where {F}
89
- return sector_type (F)
90
- end
91
-
92
- # BlockArrays interface
93
- function BlockArrays. findblock (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
94
- # find outer block corresponding to fusion trees
95
- @assert typeof ((f1, f2)) === keytype (trees_block_mapping (ft))
96
- b1 = find_sector_block .(leaves (f1), codomain_axes (ft))
97
- b2 = find_sector_block .(leaves (f2), domain_axes (ft))
98
- return Block (b1... , b2... )
99
- end
100
28
# TBD move to GradedArrays? rename findfirst_sector?
101
29
function find_sector_block (s:: AbstractSector , g:: AbstractGradedUnitRange )
102
30
return findfirst (== (s), sectors (flip_dual (g)))
103
31
end
104
32
105
- function sanitize_axes (raw_legs:: Tuple{Vararg{AbstractGradedUnitRange}} )
106
- legs = promote_sectors (raw_legs)
107
- @assert all (check_unique_sectors .(legs))
108
- return legs
109
- end
110
- sanitize_axes (legs:: BlockedTuple{2,(0, 0)} ) = TrivialSector, legs
111
- function sanitize_axes (raw_legs:: BlockedTuple{2} )
112
- flat_legs = sanitize_axes (Tuple (raw_legs))
113
- return sector_type (first (flat_legs)), BlockedTuple (flat_legs, Val (blocklengths (raw_legs)))
114
- end
115
-
116
- function check_unique_sectors (g:: AbstractGradedUnitRange )
117
- return length (unique (sectors (g))) == blocklength (g)
118
- end
119
-
120
- promote_sectors (legs:: NTuple{<:Any,<:AbstractGradedUnitRange} ) = legs # nothing to do
121
- function promote_sectors (legs)
122
- T = promote_sector_type (legs)
123
- # fuse with trivial to insert all missing arguments inside each GradedAxis
124
- # avoid depending on GradedArrays internals
125
- s0 = trivial (T)
126
- return map_sectors .(s -> only (sectors (to_gradedrange (tensor_product (s0, s)))), legs)
127
- end
128
-
129
- function promote_sector_type (legs)
130
- # fuse trivial sectors to produce unified type
131
- # avoid depending on GradedArrays internals
132
- return sector_type (tensor_product (trivial .(legs)... ))
33
+ # TBD move to GradedArrays?
34
+ function checkaxes (:: Type{Bool} , axes1, axes2)
35
+ return length (axes1) == length (axes2) && all (space_isequal .(axes1, axes2))
133
36
end
134
37
135
- # initialize with already computed data_matrix
136
- function FusionTensor (
137
- x,
138
- codomain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
139
- domain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
140
- )
141
- return FusionTensor (x, tuplemortar ((codomain_legs, domain_legs)))
38
+ # TBD move to GradedArrays?
39
+ checkaxes_dual (axes1, axes2) = checkaxes (axes1, dual .(axes2))
40
+ function checkaxes (ax1, ax2)
41
+ return checkaxes (Bool, ax1, ax2) ||
42
+ throw (DimensionMismatch (lazy " $ax1 does not match $ax2" ))
142
43
end
143
44
144
- function FusionTensor (mat:: AbstractMatrix , legs:: BlockedTuple{2} )
145
- # init with empty data_matrix to construct trees_block_mapping
146
- ft = FusionTensor (eltype (mat), legs)
147
- for b in eachblockstoredindex (mat)
148
- @assert b in eachblockstoredindex (data_matrix (ft)) # check matrix block is allowed
149
- data_matrix (ft)[b] = mat[b]
150
- end
151
- return ft
45
+ function to_blockindexrange (b1:: BlockIndexRange{1} , b2:: BlockIndexRange{1} )
46
+ t = (b1, b2)
47
+ return Block (Block .(t))[to_block_indices .(t)... ]
152
48
end
153
49
154
50
function flip_domain (nonflipped_col_axis, nonflipped_trees_to_ranges)
@@ -159,21 +55,8 @@ function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
159
55
return col_axis, domain_trees_to_ranges_mapping
160
56
end
161
57
162
- # empty matrix
163
- function FusionTensor (elt:: Type , raw_legs:: BlockedTuple{2} )
164
- S, legs = sanitize_axes (raw_legs)
165
- row_axis, codomain_trees_to_ranges = fuse_axes (S, legs[Block (1 )])
166
- col_axis, domain_trees_to_ranges = flip_domain (fuse_axes (S, dual .(legs[Block (2 )]))... )
167
-
168
- mat = initialize_data_matrix (elt, row_axis, col_axis)
169
- tree_to_block_mapping = intersect_codomain_domain (
170
- codomain_trees_to_ranges, domain_trees_to_ranges
171
- )
172
- return FusionTensor (mat, legs, tree_to_block_mapping)
173
- end
174
-
175
58
function fuse_axes (:: Type{S} , :: Tuple{} ) where {S<: AbstractSector }
176
- fused_axis = gradedrange ([ trivial (S) => 1 ] )
59
+ fused_axis = trivial_axis (S )
177
60
trees_to_ranges_mapping = Dict ([SectorFusionTree {S} () => Block (1 )[1 : 1 ]])
178
61
return fused_axis, trees_to_ranges_mapping
179
62
end
@@ -213,11 +96,6 @@ function compute_inner_ranges(fusion_trees_mult)
213
96
return fused_leg, range_mapping
214
97
end
215
98
216
- function to_blockindexrange (b1:: BlockIndexRange{1} , b2:: BlockIndexRange{1} )
217
- t = (b1, b2)
218
- return Block (Block .(t))[to_block_indices .(t)... ]
219
- end
220
-
221
99
function intersect_codomain_domain (
222
100
codomain_trees_to_ranges_mapping:: Dict{<:SectorFusionTree,<:BlockIndexRange{1}} ,
223
101
domain_trees_to_ranges_mapping:: Dict{<:SectorFusionTree,<:BlockIndexRange{1}} ,
@@ -257,11 +135,112 @@ function initialize_data_matrix(
257
135
return mat
258
136
end
259
137
260
- checkaxes_dual (axes1, axes2) = checkaxes (axes1, dual .(axes2))
261
- function checkaxes (ax1, ax2)
262
- return checkaxes (Bool, ax1, ax2) ||
263
- throw (DimensionMismatch (lazy " $ax1 does not match $ax2" ))
138
+ # ==================================== Definitions =======================================
139
+
140
+ struct FusionTensor{T,N,Axes<: FusionTensorAxes ,Mat<: AbstractMatrix{T} ,Mapping} < :
141
+ AbstractArray{T,N}
142
+ data_matrix:: Mat
143
+ axes:: Axes
144
+ trees_block_mapping:: Mapping
145
+
146
+ # inner constructor to impose constraints on types
147
+ function FusionTensor {T,N,Axes,Mat,Mapping} (
148
+ mat, legs, trees_block_mapping
149
+ ) where {T,N,Axes,Mat,Mapping}
150
+ S = sector_type (legs)
151
+ @assert keytype (trees_block_mapping) < :
152
+ Tuple{<: SectorFusionTree{S} ,<: SectorFusionTree{S} }
153
+ return new {T,N,Axes,Mat,Mapping} (mat, legs, trees_block_mapping)
154
+ end
264
155
end
265
- function checkaxes (:: Type{Bool} , axes1, axes2)
266
- return length (axes1) == length (axes2) && all (space_isequal .(axes1, axes2))
156
+
157
+ # ===================================== Accessors ========================================
158
+
159
+ data_matrix (ft:: FusionTensor ) = ft. data_matrix
160
+ trees_block_mapping (ft:: FusionTensor ) = ft. trees_block_mapping
161
+
162
+ # ==================================== Constructors ======================================
163
+
164
+ function FusionTensor (
165
+ mat:: AbstractMatrix ,
166
+ legs:: FusionTensorAxes ,
167
+ trees_block_mapping:: Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}} ,
168
+ )
169
+ return FusionTensor{
170
+ eltype (mat),length (legs),typeof (legs),typeof (mat),typeof (trees_block_mapping)
171
+ }(
172
+ mat, legs, trees_block_mapping
173
+ )
174
+ end
175
+
176
+ # empty matrix
177
+ function FusionTensor (elt:: Type , legs:: FusionTensorAxes )
178
+ S = sector_type (legs)
179
+ row_axis, codomain_trees_to_ranges = fuse_axes (S, codomain (legs))
180
+ col_axis, domain_trees_to_ranges = flip_domain (fuse_axes (S, dual .(domain (legs)))... )
181
+
182
+ mat = initialize_data_matrix (elt, row_axis, col_axis)
183
+ tree_to_block_mapping = intersect_codomain_domain (
184
+ codomain_trees_to_ranges, domain_trees_to_ranges
185
+ )
186
+ return FusionTensor (mat, legs, tree_to_block_mapping)
187
+ end
188
+
189
+ # constructor from precomputed data_matrix
190
+ function FusionTensor (mat:: AbstractMatrix , legs:: FusionTensorAxes )
191
+ # init with empty data_matrix to construct trees_block_mapping
192
+ ft = FusionTensor (eltype (mat), legs)
193
+ for b in eachblockstoredindex (mat)
194
+ b in eachblockstoredindex (data_matrix (ft)) ||
195
+ throw (ArgumentError (" matrix block $b is not allowed" ))
196
+ data_matrix (ft)[b] = mat[b]
197
+ end
198
+ return ft
199
+ end
200
+
201
+ FusionTensor (x, legs:: BlockedTuple{2} ) = FusionTensor (x, FusionTensorAxes (legs))
202
+
203
+ # constructor from split axes
204
+ function FusionTensor (
205
+ x,
206
+ codomain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
207
+ domain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
208
+ )
209
+ return FusionTensor (x, tuplemortar ((codomain_legs, domain_legs)))
210
+ end
211
+
212
+ # ================================ BlockArrays interface =================================
213
+
214
+ function BlockArrays. findblock (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
215
+ # find outer block corresponding to fusion trees
216
+ @assert typeof ((f1, f2)) === keytype (trees_block_mapping (ft))
217
+ b1 = find_sector_block .(leaves (f1), codomain_axes (ft))
218
+ b2 = find_sector_block .(leaves (f2), domain_axes (ft))
219
+ return Block (b1... , b2... )
220
+ end
221
+
222
+ # ============================== GradedArrays interface ==================================
223
+
224
+ function GradedArrays. sector_type (:: Type{FT} ) where {FT<: FusionTensor }
225
+ return sector_type (type_parameters (FT, 3 ))
226
+ end
227
+
228
+ # ============================== FusionTensor interface ==================================
229
+
230
+ # misc access
231
+ codomain_axes (ft:: FusionTensor ) = codomain (axes (ft))
232
+
233
+ domain_axes (ft:: FusionTensor ) = domain (axes (ft))
234
+
235
+ codomain_axis (ft:: FusionTensor ) = fused_codomain (axes (ft))
236
+
237
+ domain_axis (ft:: FusionTensor ) = fused_domain (axes (ft))
238
+
239
+ ndims_codomain (ft:: FusionTensor ) = length_codomain (axes (ft))
240
+
241
+ ndims_domain (ft:: FusionTensor ) = length_domain (axes (ft))
242
+
243
+ function charge_block_size (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
244
+ b = Tuple (findblock (ft, f1, f2))
245
+ return ntuple (i -> sector_multiplicity (axes (ft, i)[b[i]]), ndims (ft))
267
246
end
0 commit comments