Skip to content

Commit 6a99562

Browse files
committed
reorder file
1 parent ddfff17 commit 6a99562

File tree

2 files changed

+120
-99
lines changed

2 files changed

+120
-99
lines changed

src/fusiontensor/fusiontensor.jl

Lines changed: 119 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -23,90 +23,28 @@ using TensorAlgebra: BlockedTuple, tuplemortar
2323
using TensorProducts: tensor_product
2424
using TypeParameterAccessors: type_parameters
2525

26-
struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} <:
27-
AbstractArray{T,N}
28-
data_matrix::Mat
29-
axes::Axes
30-
trees_block_mapping::Mapping
31-
32-
# inner constructor to impose constraints on types
33-
function FusionTensor{T,N,Axes,Mat,Mapping}(
34-
mat, legs, trees_block_mapping
35-
) where {T,N,Axes,Mat,Mapping}
36-
S = sector_type(legs)
37-
@assert keytype(trees_block_mapping) <:
38-
Tuple{<:SectorFusionTree{S},<:SectorFusionTree{S}}
39-
return new{T,N,Axes,Mat,Mapping}(mat, legs, trees_block_mapping)
40-
end
41-
end
26+
# ======================================= Misc ===========================================
4227

43-
function FusionTensor(
44-
mat::AbstractMatrix,
45-
legs::FusionTensorAxes,
46-
trees_block_mapping::Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}},
47-
)
48-
return FusionTensor{
49-
eltype(mat),length(legs),typeof(legs),typeof(mat),typeof(trees_block_mapping)
50-
}(
51-
mat, legs, trees_block_mapping
52-
)
53-
end
54-
55-
# getters
56-
data_matrix(ft::FusionTensor) = ft.data_matrix
57-
trees_block_mapping(ft::FusionTensor) = ft.trees_block_mapping
58-
59-
# misc access
60-
for f in [
61-
:(codomain_axes),
62-
:(codomain_axis),
63-
:(domain_axes),
64-
:(domain_axis),
65-
:(ndims_codomain),
66-
:(ndims_domain),
67-
:(GradedArrays.sector_type),
68-
]
69-
@eval $f(ft::FusionTensor) = $f(axes(ft))
70-
end
71-
72-
function charge_block_size(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
73-
b = Tuple(findblock(ft, f1, f2))
74-
return ntuple(i -> sector_multiplicity(axes(ft, i)[b[i]]), ndims(ft))
75-
end
76-
77-
# BlockArrays interface
78-
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
79-
# find outer block corresponding to fusion trees
80-
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
81-
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
82-
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
83-
return Block(b1..., b2...)
84-
end
8528
# TBD move to GradedArrays? rename findfirst_sector?
8629
function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange)
8730
return findfirst(==(s), sectors(flip_dual(g)))
8831
end
8932

90-
# constructor from split axes
91-
function FusionTensor(
92-
x,
93-
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
94-
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
95-
)
96-
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
33+
# TBD move to GradedArrays?
34+
function checkaxes(::Type{Bool}, axes1, axes2)
35+
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
9736
end
9837

99-
FusionTensor(x, legs::BlockedTuple{2}) = FusionTensor(x, FusionTensorAxes(legs))
38+
# TBD move to GradedArrays?
39+
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
40+
function checkaxes(ax1, ax2)
41+
return checkaxes(Bool, ax1, ax2) ||
42+
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
43+
end
10044

101-
#constructor from precomputed data_matrix
102-
function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
103-
# init with empty data_matrix to construct trees_block_mapping
104-
ft = FusionTensor(eltype(mat), legs)
105-
for b in eachblockstoredindex(mat)
106-
@assert b in eachblockstoredindex(data_matrix(ft)) # check matrix block is allowed
107-
data_matrix(ft)[b] = mat[b]
108-
end
109-
return ft
45+
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
46+
t = (b1, b2)
47+
return Block(Block.(t))[to_block_indices.(t)...]
11048
end
11149

11250
function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
@@ -117,19 +55,6 @@ function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
11755
return col_axis, domain_trees_to_ranges_mapping
11856
end
11957

120-
# empty matrix
121-
function FusionTensor(elt::Type, legs::FusionTensorAxes)
122-
S = sector_type(legs)
123-
row_axis, codomain_trees_to_ranges = fuse_axes(S, codomain_axes(legs))
124-
col_axis, domain_trees_to_ranges = flip_domain(fuse_axes(S, dual.(domain_axes(legs)))...)
125-
126-
mat = initialize_data_matrix(elt, row_axis, col_axis)
127-
tree_to_block_mapping = intersect_codomain_domain(
128-
codomain_trees_to_ranges, domain_trees_to_ranges
129-
)
130-
return FusionTensor(mat, legs, tree_to_block_mapping)
131-
end
132-
13358
function fuse_axes(::Type{S}, ::Tuple{}) where {S<:AbstractSector}
13459
fused_axis = dummy_axis(S)
13560
trees_to_ranges_mapping = Dict([SectorFusionTree{S}() => Block(1)[1:1]])
@@ -171,11 +96,6 @@ function compute_inner_ranges(fusion_trees_mult)
17196
return fused_leg, range_mapping
17297
end
17398

174-
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
175-
t = (b1, b2)
176-
return Block(Block.(t))[to_block_indices.(t)...]
177-
end
178-
17999
function intersect_codomain_domain(
180100
codomain_trees_to_ranges_mapping::Dict{<:SectorFusionTree,<:BlockIndexRange{1}},
181101
domain_trees_to_ranges_mapping::Dict{<:SectorFusionTree,<:BlockIndexRange{1}},
@@ -215,11 +135,111 @@ function initialize_data_matrix(
215135
return mat
216136
end
217137

218-
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
219-
function checkaxes(ax1, ax2)
220-
return checkaxes(Bool, ax1, ax2) ||
221-
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
138+
# ==================================== Definitions =======================================
139+
140+
struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} <:
141+
AbstractArray{T,N}
142+
data_matrix::Mat
143+
axes::Axes
144+
trees_block_mapping::Mapping
145+
146+
# inner constructor to impose constraints on types
147+
function FusionTensor{T,N,Axes,Mat,Mapping}(
148+
mat, legs, trees_block_mapping
149+
) where {T,N,Axes,Mat,Mapping}
150+
S = sector_type(legs)
151+
@assert keytype(trees_block_mapping) <:
152+
Tuple{<:SectorFusionTree{S},<:SectorFusionTree{S}}
153+
return new{T,N,Axes,Mat,Mapping}(mat, legs, trees_block_mapping)
154+
end
222155
end
223-
function checkaxes(::Type{Bool}, axes1, axes2)
224-
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
156+
157+
# ===================================== Accessors ========================================
158+
159+
data_matrix(ft::FusionTensor) = ft.data_matrix
160+
trees_block_mapping(ft::FusionTensor) = ft.trees_block_mapping
161+
162+
# ==================================== Constructors ======================================
163+
164+
function FusionTensor(
165+
mat::AbstractMatrix,
166+
legs::FusionTensorAxes,
167+
trees_block_mapping::Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}},
168+
)
169+
return FusionTensor{
170+
eltype(mat),length(legs),typeof(legs),typeof(mat),typeof(trees_block_mapping)
171+
}(
172+
mat, legs, trees_block_mapping
173+
)
174+
end
175+
176+
# empty matrix
177+
function FusionTensor(elt::Type, legs::FusionTensorAxes)
178+
S = sector_type(legs)
179+
row_axis, codomain_trees_to_ranges = fuse_axes(S, codomain_axes(legs))
180+
col_axis, domain_trees_to_ranges = flip_domain(fuse_axes(S, dual.(domain_axes(legs)))...)
181+
182+
mat = initialize_data_matrix(elt, row_axis, col_axis)
183+
tree_to_block_mapping = intersect_codomain_domain(
184+
codomain_trees_to_ranges, domain_trees_to_ranges
185+
)
186+
return FusionTensor(mat, legs, tree_to_block_mapping)
187+
end
188+
189+
#constructor from precomputed data_matrix
190+
function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
191+
# init with empty data_matrix to construct trees_block_mapping
192+
ft = FusionTensor(eltype(mat), legs)
193+
for b in eachblockstoredindex(mat)
194+
b in eachblockstoredindex(data_matrix(ft)) ||
195+
throw(ArgumentError("matrix block $b is not allowed"))
196+
data_matrix(ft)[b] = mat[b]
197+
end
198+
return ft
199+
end
200+
201+
FusionTensor(x, legs::BlockedTuple{2}) = FusionTensor(x, FusionTensorAxes(legs))
202+
203+
# constructor from split axes
204+
function FusionTensor(
205+
x,
206+
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
207+
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
208+
)
209+
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
210+
end
211+
212+
# ================================ BlockArrays interface =================================
213+
214+
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
215+
# find outer block corresponding to fusion trees
216+
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
217+
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
218+
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
219+
return Block(b1..., b2...)
220+
end
221+
222+
# ============================== GradedArrays interface ==================================
223+
224+
function GradedArrays.sector_type(::Type{FT}) where {FT<:FusionTensor}
225+
return sector_type(type_parameters(FT, 3))
226+
end
227+
228+
# ============================== FusionTensor interface ==================================
229+
230+
# misc access
231+
for f in [
232+
:(codomain_axes),
233+
:(codomain_axis),
234+
:(domain_axes),
235+
:(domain_axis),
236+
:(ndims_codomain),
237+
:(ndims_domain),
238+
]
239+
@eval $f(ft::FusionTensor) = $f(axes(ft))
240+
end
241+
242+
function charge_block_size(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
243+
b = Tuple(findblock(ft, f1, f2))
244+
return ntuple(i -> sector_multiplicity(axes(ft, i)[b[i]]), ndims(ft))
225245
end

test/test_basics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ include("setup.jl")
5858
@test isnothing(check_sanity(ft0))
5959
@test isnothing(check_sanity(ft1))
6060
@test sector_type(ft1) === U1{Int}
61+
@test sector_type(typeof(ft1)) === U1{Int}
6162

6263
# Base methods
6364
@test eltype(ft1) === Float64

0 commit comments

Comments
 (0)