@@ -23,90 +23,28 @@ using TensorAlgebra: BlockedTuple, tuplemortar
2323using TensorProducts: tensor_product
2424using TypeParameterAccessors: type_parameters
2525
26- struct FusionTensor{T,N,Axes<: FusionTensorAxes ,Mat<: AbstractMatrix{T} ,Mapping} < :
27- AbstractArray{T,N}
28- data_matrix:: Mat
29- axes:: Axes
30- trees_block_mapping:: Mapping
31-
32- # inner constructor to impose constraints on types
33- function FusionTensor {T,N,Axes,Mat,Mapping} (
34- mat, legs, trees_block_mapping
35- ) where {T,N,Axes,Mat,Mapping}
36- S = sector_type (legs)
37- @assert keytype (trees_block_mapping) < :
38- Tuple{<: SectorFusionTree{S} ,<: SectorFusionTree{S} }
39- return new {T,N,Axes,Mat,Mapping} (mat, legs, trees_block_mapping)
40- end
41- end
26+ # ======================================= Misc ===========================================
4227
43- function FusionTensor (
44- mat:: AbstractMatrix ,
45- legs:: FusionTensorAxes ,
46- trees_block_mapping:: Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}} ,
47- )
48- return FusionTensor{
49- eltype (mat),length (legs),typeof (legs),typeof (mat),typeof (trees_block_mapping)
50- }(
51- mat, legs, trees_block_mapping
52- )
53- end
54-
55- # getters
56- data_matrix (ft:: FusionTensor ) = ft. data_matrix
57- trees_block_mapping (ft:: FusionTensor ) = ft. trees_block_mapping
58-
59- # misc access
60- for f in [
61- :(codomain_axes),
62- :(codomain_axis),
63- :(domain_axes),
64- :(domain_axis),
65- :(ndims_codomain),
66- :(ndims_domain),
67- :(GradedArrays. sector_type),
68- ]
69- @eval $ f (ft:: FusionTensor ) = $ f (axes (ft))
70- end
71-
72- function charge_block_size (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
73- b = Tuple (findblock (ft, f1, f2))
74- return ntuple (i -> sector_multiplicity (axes (ft, i)[b[i]]), ndims (ft))
75- end
76-
77- # BlockArrays interface
78- function BlockArrays. findblock (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
79- # find outer block corresponding to fusion trees
80- @assert typeof ((f1, f2)) === keytype (trees_block_mapping (ft))
81- b1 = find_sector_block .(leaves (f1), codomain_axes (ft))
82- b2 = find_sector_block .(leaves (f2), domain_axes (ft))
83- return Block (b1... , b2... )
84- end
8528# TBD move to GradedArrays? rename findfirst_sector?
8629function find_sector_block (s:: AbstractSector , g:: AbstractGradedUnitRange )
8730 return findfirst (== (s), sectors (flip_dual (g)))
8831end
8932
90- # constructor from split axes
91- function FusionTensor (
92- x,
93- codomain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
94- domain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
95- )
96- return FusionTensor (x, tuplemortar ((codomain_legs, domain_legs)))
33+ # TBD move to GradedArrays?
34+ function checkaxes (:: Type{Bool} , axes1, axes2)
35+ return length (axes1) == length (axes2) && all (space_isequal .(axes1, axes2))
9736end
9837
99- FusionTensor (x, legs:: BlockedTuple{2} ) = FusionTensor (x, FusionTensorAxes (legs))
38+ # TBD move to GradedArrays?
39+ checkaxes_dual (axes1, axes2) = checkaxes (axes1, dual .(axes2))
40+ function checkaxes (ax1, ax2)
41+ return checkaxes (Bool, ax1, ax2) ||
42+ throw (DimensionMismatch (lazy " $ax1 does not match $ax2" ))
43+ end
10044
101- # constructor from precomputed data_matrix
102- function FusionTensor (mat:: AbstractMatrix , legs:: FusionTensorAxes )
103- # init with empty data_matrix to construct trees_block_mapping
104- ft = FusionTensor (eltype (mat), legs)
105- for b in eachblockstoredindex (mat)
106- @assert b in eachblockstoredindex (data_matrix (ft)) # check matrix block is allowed
107- data_matrix (ft)[b] = mat[b]
108- end
109- return ft
45+ function to_blockindexrange (b1:: BlockIndexRange{1} , b2:: BlockIndexRange{1} )
46+ t = (b1, b2)
47+ return Block (Block .(t))[to_block_indices .(t)... ]
11048end
11149
11250function flip_domain (nonflipped_col_axis, nonflipped_trees_to_ranges)
@@ -117,19 +55,6 @@ function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
11755 return col_axis, domain_trees_to_ranges_mapping
11856end
11957
120- # empty matrix
121- function FusionTensor (elt:: Type , legs:: FusionTensorAxes )
122- S = sector_type (legs)
123- row_axis, codomain_trees_to_ranges = fuse_axes (S, codomain_axes (legs))
124- col_axis, domain_trees_to_ranges = flip_domain (fuse_axes (S, dual .(domain_axes (legs)))... )
125-
126- mat = initialize_data_matrix (elt, row_axis, col_axis)
127- tree_to_block_mapping = intersect_codomain_domain (
128- codomain_trees_to_ranges, domain_trees_to_ranges
129- )
130- return FusionTensor (mat, legs, tree_to_block_mapping)
131- end
132-
13358function fuse_axes (:: Type{S} , :: Tuple{} ) where {S<: AbstractSector }
13459 fused_axis = dummy_axis (S)
13560 trees_to_ranges_mapping = Dict ([SectorFusionTree {S} () => Block (1 )[1 : 1 ]])
@@ -171,11 +96,6 @@ function compute_inner_ranges(fusion_trees_mult)
17196 return fused_leg, range_mapping
17297end
17398
174- function to_blockindexrange (b1:: BlockIndexRange{1} , b2:: BlockIndexRange{1} )
175- t = (b1, b2)
176- return Block (Block .(t))[to_block_indices .(t)... ]
177- end
178-
17999function intersect_codomain_domain (
180100 codomain_trees_to_ranges_mapping:: Dict{<:SectorFusionTree,<:BlockIndexRange{1}} ,
181101 domain_trees_to_ranges_mapping:: Dict{<:SectorFusionTree,<:BlockIndexRange{1}} ,
@@ -215,11 +135,111 @@ function initialize_data_matrix(
215135 return mat
216136end
217137
218- checkaxes_dual (axes1, axes2) = checkaxes (axes1, dual .(axes2))
219- function checkaxes (ax1, ax2)
220- return checkaxes (Bool, ax1, ax2) ||
221- throw (DimensionMismatch (lazy " $ax1 does not match $ax2" ))
138+ # ==================================== Definitions =======================================
139+
140+ struct FusionTensor{T,N,Axes<: FusionTensorAxes ,Mat<: AbstractMatrix{T} ,Mapping} < :
141+ AbstractArray{T,N}
142+ data_matrix:: Mat
143+ axes:: Axes
144+ trees_block_mapping:: Mapping
145+
146+ # inner constructor to impose constraints on types
147+ function FusionTensor {T,N,Axes,Mat,Mapping} (
148+ mat, legs, trees_block_mapping
149+ ) where {T,N,Axes,Mat,Mapping}
150+ S = sector_type (legs)
151+ @assert keytype (trees_block_mapping) < :
152+ Tuple{<: SectorFusionTree{S} ,<: SectorFusionTree{S} }
153+ return new {T,N,Axes,Mat,Mapping} (mat, legs, trees_block_mapping)
154+ end
222155end
223- function checkaxes (:: Type{Bool} , axes1, axes2)
224- return length (axes1) == length (axes2) && all (space_isequal .(axes1, axes2))
156+
157+ # ===================================== Accessors ========================================
158+
159+ data_matrix (ft:: FusionTensor ) = ft. data_matrix
160+ trees_block_mapping (ft:: FusionTensor ) = ft. trees_block_mapping
161+
162+ # ==================================== Constructors ======================================
163+
164+ function FusionTensor (
165+ mat:: AbstractMatrix ,
166+ legs:: FusionTensorAxes ,
167+ trees_block_mapping:: Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}} ,
168+ )
169+ return FusionTensor{
170+ eltype (mat),length (legs),typeof (legs),typeof (mat),typeof (trees_block_mapping)
171+ }(
172+ mat, legs, trees_block_mapping
173+ )
174+ end
175+
176+ # empty matrix
177+ function FusionTensor (elt:: Type , legs:: FusionTensorAxes )
178+ S = sector_type (legs)
179+ row_axis, codomain_trees_to_ranges = fuse_axes (S, codomain_axes (legs))
180+ col_axis, domain_trees_to_ranges = flip_domain (fuse_axes (S, dual .(domain_axes (legs)))... )
181+
182+ mat = initialize_data_matrix (elt, row_axis, col_axis)
183+ tree_to_block_mapping = intersect_codomain_domain (
184+ codomain_trees_to_ranges, domain_trees_to_ranges
185+ )
186+ return FusionTensor (mat, legs, tree_to_block_mapping)
187+ end
188+
189+ # constructor from precomputed data_matrix
190+ function FusionTensor (mat:: AbstractMatrix , legs:: FusionTensorAxes )
191+ # init with empty data_matrix to construct trees_block_mapping
192+ ft = FusionTensor (eltype (mat), legs)
193+ for b in eachblockstoredindex (mat)
194+ b in eachblockstoredindex (data_matrix (ft)) ||
195+ throw (ArgumentError (" matrix block $b is not allowed" ))
196+ data_matrix (ft)[b] = mat[b]
197+ end
198+ return ft
199+ end
200+
201+ FusionTensor (x, legs:: BlockedTuple{2} ) = FusionTensor (x, FusionTensorAxes (legs))
202+
203+ # constructor from split axes
204+ function FusionTensor (
205+ x,
206+ codomain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
207+ domain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
208+ )
209+ return FusionTensor (x, tuplemortar ((codomain_legs, domain_legs)))
210+ end
211+
212+ # ================================ BlockArrays interface =================================
213+
214+ function BlockArrays. findblock (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
215+ # find outer block corresponding to fusion trees
216+ @assert typeof ((f1, f2)) === keytype (trees_block_mapping (ft))
217+ b1 = find_sector_block .(leaves (f1), codomain_axes (ft))
218+ b2 = find_sector_block .(leaves (f2), domain_axes (ft))
219+ return Block (b1... , b2... )
220+ end
221+
222+ # ============================== GradedArrays interface ==================================
223+
224+ function GradedArrays. sector_type (:: Type{FT} ) where {FT<: FusionTensor }
225+ return sector_type (type_parameters (FT, 3 ))
226+ end
227+
228+ # ============================== FusionTensor interface ==================================
229+
230+ # misc access
231+ for f in [
232+ :(codomain_axes),
233+ :(codomain_axis),
234+ :(domain_axes),
235+ :(domain_axis),
236+ :(ndims_codomain),
237+ :(ndims_domain),
238+ ]
239+ @eval $ f (ft:: FusionTensor ) = $ f (axes (ft))
240+ end
241+
242+ function charge_block_size (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
243+ b = Tuple (findblock (ft, f1, f2))
244+ return ntuple (i -> sector_multiplicity (axes (ft, i)[b[i]]), ndims (ft))
225245end
0 commit comments