@@ -19,136 +19,32 @@ using GradedArrays:
1919 sectormergesort,
2020 sectors,
2121 space_isequal
22- using TensorAlgebra: BlockedTuple, tuplemortar
22+ using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
2323using TensorProducts: tensor_product
2424using TypeParameterAccessors: type_parameters
2525
26- struct FusionTensor{T,N,Axes,Mat<: AbstractMatrix{T} ,Mapping} <: AbstractArray{T,N}
27- data_matrix:: Mat
28- axes:: Axes
29- trees_block_mapping:: Mapping
30-
31- # inner constructor to impose constraints on types
32- function FusionTensor {T,N,Axes,Mat,Mapping} (
33- mat, legs, trees_block_mapping
34- ) where {T,N,Axes,Mat,Mapping}
35- S = isempty (legs) ? TrivialSector : sector_type (first (legs))
36- @assert keytype (trees_block_mapping) < :
37- Tuple{<: SectorFusionTree{S} ,<: SectorFusionTree{S} }
38- @assert all (sector_type .(Tuple (legs)) .=== S)
39- return new {T,N,Axes,Mat,Mapping} (mat, legs, trees_block_mapping)
40- end
41- end
42-
43- function FusionTensor (
44- mat:: AbstractMatrix ,
45- legs:: BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractGradedUnitRange}}} ,
46- trees_block_mapping:: Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}} ,
47- )
48- return FusionTensor{
49- eltype (mat),length (legs),typeof (legs),typeof (mat),typeof (trees_block_mapping)
50- }(
51- mat, legs, trees_block_mapping
52- )
53- end
54-
55- # getters
56- data_matrix (ft:: FusionTensor ) = ft. data_matrix
57- trees_block_mapping (ft:: FusionTensor ) = ft. trees_block_mapping
58-
59- # misc access
60- codomain_axes (ft:: FusionTensor ) = axes (ft)[Block (1 )]
61- domain_axes (ft:: FusionTensor ) = axes (ft)[Block (2 )]
62- ndims_codomain (ft:: FusionTensor ) = length (codomain_axes (ft))
63- ndims_domain (ft:: FusionTensor ) = length (domain_axes (ft))
64-
65- dummy_axis (ft:: FusionTensor ) = dummy_axis (sector_type (ft))
66- dummy_axis (:: Type{S} ) where {S<: AbstractSector } = gradedrange ([trivial (S) => 1 ])
67-
68- function codomain_axis (ft:: FusionTensor )
69- if ndims_codomain (ft) == 0
70- return dummy_axis (ft)
71- end
72- return ⊗ (codomain_axes (ft)... )
73- end
74- function domain_axis (ft:: FusionTensor )
75- if ndims_domain (ft) == 0
76- return dummy_axis (ft)
77- end
78- return dual (⊗ (dual .(domain_axes (ft))... ))
79- end
80- function charge_block_size (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
81- b = Tuple (findblock (ft, f1, f2))
82- return ntuple (i -> sector_multiplicity (axes (ft, i)[b[i]]), ndims (ft))
83- end
26+ # ======================================= Misc ===========================================
8427
85- # GradedArrays interface
86- function GradedArrays. sector_type (
87- :: Type {<: FusionTensor{<:Any,<:Any,<:Any,<:Any,<:Dict{<:Tuple{<:Any,F}}} }
88- ) where {F}
89- return sector_type (F)
90- end
91-
92- # BlockArrays interface
93- function BlockArrays. findblock (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
94- # find outer block corresponding to fusion trees
95- @assert typeof ((f1, f2)) === keytype (trees_block_mapping (ft))
96- b1 = find_sector_block .(leaves (f1), codomain_axes (ft))
97- b2 = find_sector_block .(leaves (f2), domain_axes (ft))
98- return Block (b1... , b2... )
99- end
10028# TBD move to GradedArrays? rename findfirst_sector?
10129function find_sector_block (s:: AbstractSector , g:: AbstractGradedUnitRange )
10230 return findfirst (== (s), sectors (flip_dual (g)))
10331end
10432
105- function sanitize_axes (raw_legs:: Tuple{Vararg{AbstractGradedUnitRange}} )
106- legs = promote_sectors (raw_legs)
107- @assert all (check_unique_sectors .(legs))
108- return legs
109- end
110- sanitize_axes (legs:: BlockedTuple{2,(0, 0)} ) = TrivialSector, legs
111- function sanitize_axes (raw_legs:: BlockedTuple{2} )
112- flat_legs = sanitize_axes (Tuple (raw_legs))
113- return sector_type (first (flat_legs)), BlockedTuple (flat_legs, Val (blocklengths (raw_legs)))
114- end
115-
116- function check_unique_sectors (g:: AbstractGradedUnitRange )
117- return length (unique (sectors (g))) == blocklength (g)
118- end
119-
120- promote_sectors (legs:: NTuple{<:Any,<:AbstractGradedUnitRange} ) = legs # nothing to do
121- function promote_sectors (legs)
122- T = promote_sector_type (legs)
123- # fuse with trivial to insert all missing arguments inside each GradedAxis
124- # avoid depending on GradedArrays internals
125- s0 = trivial (T)
126- return map_sectors .(s -> only (sectors (to_gradedrange (tensor_product (s0, s)))), legs)
127- end
128-
129- function promote_sector_type (legs)
130- # fuse trivial sectors to produce unified type
131- # avoid depending on GradedArrays internals
132- return sector_type (tensor_product (trivial .(legs)... ))
33+ # TBD move to GradedArrays?
34+ function checkaxes (:: Type{Bool} , axes1, axes2)
35+ return length (axes1) == length (axes2) && all (space_isequal .(axes1, axes2))
13336end
13437
135- # initialize with already computed data_matrix
136- function FusionTensor (
137- x,
138- codomain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
139- domain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
140- )
141- return FusionTensor (x, tuplemortar ((codomain_legs, domain_legs)))
38+ # TBD move to GradedArrays?
39+ checkaxes_dual (axes1, axes2) = checkaxes (axes1, dual .(axes2))
40+ function checkaxes (ax1, ax2)
41+ return checkaxes (Bool, ax1, ax2) ||
42+ throw (DimensionMismatch (lazy " $ax1 does not match $ax2" ))
14243end
14344
144- function FusionTensor (mat:: AbstractMatrix , legs:: BlockedTuple{2} )
145- # init with empty data_matrix to construct trees_block_mapping
146- ft = FusionTensor (eltype (mat), legs)
147- for b in eachblockstoredindex (mat)
148- @assert b in eachblockstoredindex (data_matrix (ft)) # check matrix block is allowed
149- data_matrix (ft)[b] = mat[b]
150- end
151- return ft
45+ function to_blockindexrange (b1:: BlockIndexRange{1} , b2:: BlockIndexRange{1} )
46+ t = (b1, b2)
47+ return Block (Block .(t))[to_block_indices .(t)... ]
15248end
15349
15450function flip_domain (nonflipped_col_axis, nonflipped_trees_to_ranges)
@@ -159,21 +55,8 @@ function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
15955 return col_axis, domain_trees_to_ranges_mapping
16056end
16157
162- # empty matrix
163- function FusionTensor (elt:: Type , raw_legs:: BlockedTuple{2} )
164- S, legs = sanitize_axes (raw_legs)
165- row_axis, codomain_trees_to_ranges = fuse_axes (S, legs[Block (1 )])
166- col_axis, domain_trees_to_ranges = flip_domain (fuse_axes (S, dual .(legs[Block (2 )]))... )
167-
168- mat = initialize_data_matrix (elt, row_axis, col_axis)
169- tree_to_block_mapping = intersect_codomain_domain (
170- codomain_trees_to_ranges, domain_trees_to_ranges
171- )
172- return FusionTensor (mat, legs, tree_to_block_mapping)
173- end
174-
17558function fuse_axes (:: Type{S} , :: Tuple{} ) where {S<: AbstractSector }
176- fused_axis = gradedrange ([ trivial (S) => 1 ] )
59+ fused_axis = trivial_axis (S )
17760 trees_to_ranges_mapping = Dict ([SectorFusionTree {S} () => Block (1 )[1 : 1 ]])
17861 return fused_axis, trees_to_ranges_mapping
17962end
@@ -213,11 +96,6 @@ function compute_inner_ranges(fusion_trees_mult)
21396 return fused_leg, range_mapping
21497end
21598
216- function to_blockindexrange (b1:: BlockIndexRange{1} , b2:: BlockIndexRange{1} )
217- t = (b1, b2)
218- return Block (Block .(t))[to_block_indices .(t)... ]
219- end
220-
22199function intersect_codomain_domain (
222100 codomain_trees_to_ranges_mapping:: Dict{<:SectorFusionTree,<:BlockIndexRange{1}} ,
223101 domain_trees_to_ranges_mapping:: Dict{<:SectorFusionTree,<:BlockIndexRange{1}} ,
@@ -257,11 +135,112 @@ function initialize_data_matrix(
257135 return mat
258136end
259137
260- checkaxes_dual (axes1, axes2) = checkaxes (axes1, dual .(axes2))
261- function checkaxes (ax1, ax2)
262- return checkaxes (Bool, ax1, ax2) ||
263- throw (DimensionMismatch (lazy " $ax1 does not match $ax2" ))
138+ # ==================================== Definitions =======================================
139+
140+ struct FusionTensor{T,N,Axes<: FusionTensorAxes ,Mat<: AbstractMatrix{T} ,Mapping} < :
141+ AbstractArray{T,N}
142+ data_matrix:: Mat
143+ axes:: Axes
144+ trees_block_mapping:: Mapping
145+
146+ # inner constructor to impose constraints on types
147+ function FusionTensor {T,N,Axes,Mat,Mapping} (
148+ mat, legs, trees_block_mapping
149+ ) where {T,N,Axes,Mat,Mapping}
150+ S = sector_type (legs)
151+ @assert keytype (trees_block_mapping) < :
152+ Tuple{<: SectorFusionTree{S} ,<: SectorFusionTree{S} }
153+ return new {T,N,Axes,Mat,Mapping} (mat, legs, trees_block_mapping)
154+ end
264155end
265- function checkaxes (:: Type{Bool} , axes1, axes2)
266- return length (axes1) == length (axes2) && all (space_isequal .(axes1, axes2))
156+
157+ # ===================================== Accessors ========================================
158+
159+ data_matrix (ft:: FusionTensor ) = ft. data_matrix
160+ trees_block_mapping (ft:: FusionTensor ) = ft. trees_block_mapping
161+
162+ # ==================================== Constructors ======================================
163+
164+ function FusionTensor (
165+ mat:: AbstractMatrix ,
166+ legs:: FusionTensorAxes ,
167+ trees_block_mapping:: Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}} ,
168+ )
169+ return FusionTensor{
170+ eltype (mat),length (legs),typeof (legs),typeof (mat),typeof (trees_block_mapping)
171+ }(
172+ mat, legs, trees_block_mapping
173+ )
174+ end
175+
176+ # empty matrix
177+ function FusionTensor (elt:: Type , legs:: FusionTensorAxes )
178+ S = sector_type (legs)
179+ row_axis, codomain_trees_to_ranges = fuse_axes (S, codomain (legs))
180+ col_axis, domain_trees_to_ranges = flip_domain (fuse_axes (S, dual .(domain (legs)))... )
181+
182+ mat = initialize_data_matrix (elt, row_axis, col_axis)
183+ tree_to_block_mapping = intersect_codomain_domain (
184+ codomain_trees_to_ranges, domain_trees_to_ranges
185+ )
186+ return FusionTensor (mat, legs, tree_to_block_mapping)
187+ end
188+
189+ # constructor from precomputed data_matrix
190+ function FusionTensor (mat:: AbstractMatrix , legs:: FusionTensorAxes )
191+ # init with empty data_matrix to construct trees_block_mapping
192+ ft = FusionTensor (eltype (mat), legs)
193+ for b in eachblockstoredindex (mat)
194+ b in eachblockstoredindex (data_matrix (ft)) ||
195+ throw (ArgumentError (" matrix block $b is not allowed" ))
196+ data_matrix (ft)[b] = mat[b]
197+ end
198+ return ft
199+ end
200+
201+ FusionTensor (x, legs:: BlockedTuple{2} ) = FusionTensor (x, FusionTensorAxes (legs))
202+
203+ # constructor from split axes
204+ function FusionTensor (
205+ x,
206+ codomain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
207+ domain_legs:: Tuple{Vararg{AbstractGradedUnitRange}} ,
208+ )
209+ return FusionTensor (x, tuplemortar ((codomain_legs, domain_legs)))
210+ end
211+
212+ # ================================ BlockArrays interface =================================
213+
214+ function BlockArrays. findblock (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
215+ # find outer block corresponding to fusion trees
216+ @assert typeof ((f1, f2)) === keytype (trees_block_mapping (ft))
217+ b1 = find_sector_block .(leaves (f1), codomain_axes (ft))
218+ b2 = find_sector_block .(leaves (f2), domain_axes (ft))
219+ return Block (b1... , b2... )
220+ end
221+
222+ # ============================== GradedArrays interface ==================================
223+
224+ function GradedArrays. sector_type (:: Type{FT} ) where {FT<: FusionTensor }
225+ return sector_type (type_parameters (FT, 3 ))
226+ end
227+
228+ # ============================== FusionTensor interface ==================================
229+
230+ # misc access
231+ codomain_axes (ft:: FusionTensor ) = codomain (axes (ft))
232+
233+ domain_axes (ft:: FusionTensor ) = domain (axes (ft))
234+
235+ codomain_axis (ft:: FusionTensor ) = fused_codomain (axes (ft))
236+
237+ domain_axis (ft:: FusionTensor ) = fused_domain (axes (ft))
238+
239+ ndims_codomain (ft:: FusionTensor ) = length_codomain (axes (ft))
240+
241+ ndims_domain (ft:: FusionTensor ) = length_domain (axes (ft))
242+
243+ function charge_block_size (ft:: FusionTensor , f1:: SectorFusionTree , f2:: SectorFusionTree )
244+ b = Tuple (findblock (ft, f1, f2))
245+ return ntuple (i -> sector_multiplicity (axes (ft, i)[b[i]]), ndims (ft))
267246end
0 commit comments