Skip to content

Commit 260760f

Browse files
committed
remove FusedAxes
1 parent 20a3119 commit 260760f

File tree

3 files changed

+87
-121
lines changed

3 files changed

+87
-121
lines changed

src/FusionTensors.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module FusionTensors
22

33
include("fusion_trees/fusiontree.jl")
44
include("fusion_trees/clebsch_gordan_tensors.jl")
5-
include("fusiontensor/fusedaxes.jl")
65
include("fusiontensor/fusiontensor.jl")
76
include("fusiontensor/base_interface.jl")
87
include("fusiontensor/array_cast.jl")

src/fusiontensor/fusedaxes.jl

Lines changed: 0 additions & 98 deletions
This file was deleted.

src/fusiontensor/fusiontensor.jl

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# This file defines struct FusionTensor and constructors
22

3-
using BlockArrays: AbstractBlockMatrix, BlockArrays, blocklength, findblock
3+
using BlockArrays: AbstractBlockMatrix, BlockArrays, BlockIndexRange, blocklength, findblock
44

5-
using BlockSparseArrays: AbstractBlockSparseMatrix, BlockSparseArray, eachblockstoredindex
5+
using BlockSparseArrays:
6+
AbstractBlockSparseMatrix, BlockSparseArray, eachblockstoredindex, to_block_indices
67
using GradedUnitRanges:
78
AbstractGradedUnitRange,
89
blocklabels,
10+
blockmergesort,
911
dual,
12+
gradedrange,
1013
isdual,
1114
map_blocklabels,
1215
sector_type,
@@ -20,7 +23,6 @@ struct FusionTensor{T,N,CoDomainAxes,DomainAxes,Mat,Mapping} <: AbstractArray{T,
2023
trees_block_mapping::Mapping
2124

2225
# inner constructor to impose constraints on types
23-
# TBD replace codomain_legs with FusedAxes(codomain_legs)?
2426
function FusionTensor(
2527
mat::AbstractMatrix,
2628
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
@@ -88,12 +90,22 @@ function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::Secto
8890
return Block(b1..., b2...)
8991
end
9092

91-
sanitize_axes(::Tuple{}) = ()
9293
function sanitize_axes(raw_legs::Tuple{Vararg{AbstractGradedUnitRange}})
9394
legs = promote_sectors(typeof(first(raw_legs)), raw_legs)
9495
@assert all(check_unique_blocklabels.(legs))
9596
return legs
9697
end
98+
sanitize_axes(::Tuple{}, ::Tuple{}) = TrivialSector, (), ()
99+
function sanitize_axes(
100+
codomain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
101+
domain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
102+
)
103+
legs = sanitize_axes((codomain_legs_raw..., domain_legs_raw...))
104+
S = sector_type(first(legs))
105+
codomain_legs = legs[begin:length(codomain_legs_raw)]
106+
domain_legs = legs[(length(codomain_legs_raw) + 1):end]
107+
return S, domain_legs, codomain_legs
108+
end
97109

98110
function check_unique_blocklabels(g::AbstractGradedUnitRange)
99111
return length(unique(blocklabels(g))) == blocklength(g)
@@ -145,33 +157,86 @@ function FusionTensor(
145157
codomain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
146158
domain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
147159
)
148-
legs = sanitize_axes((codomain_legs_raw..., domain_legs_raw...))
149-
S = sector_type(first(legs))
150-
codomain_legs = legs[begin:length(codomain_legs_raw)]
151-
domain_legs = legs[(length(codomain_legs_raw) + 1):end]
152-
codomain_fused_axes = FusedAxes{S}(codomain_legs)
153-
domain_fused_axes = FusedAxes{S}(dual.(domain_legs))
154-
mat = initialize_data_matrix(elt, codomain_fused_axes, domain_fused_axes)
155-
tree_to_block_mapping = intersect_sectors(codomain_fused_axes, domain_fused_axes)
160+
S, domain_legs, codomain_legs = sanitize_axes(codomain_legs_raw, domain_legs_raw)
161+
162+
row_axis, codomain_trees_to_ranges_mapping = fuse_axes(S, codomain_legs)
163+
nondual_col_axis, domain_trees_to_ranges_mapping = fuse_axes(S, dual.(domain_legs))
164+
165+
mat = initialize_data_matrix(elt, row_axis, nondual_col_axis)
166+
tree_to_block_mapping = intersect_sectors(
167+
codomain_trees_to_ranges_mapping, domain_trees_to_ranges_mapping
168+
)
156169
return FusionTensor(mat, codomain_legs, domain_legs, tree_to_block_mapping)
157170
end
158171

159-
function FusionTensor(elt::Type, ::Tuple{}, ::Tuple{})
160-
codomain_fused_axes = FusedAxes{TrivialSector}(())
161-
domain_fused_axes = FusedAxes{TrivialSector}(())
162-
mat = initialize_data_matrix(elt, codomain_fused_axes, domain_fused_axes)
163-
tree_to_block_mapping = intersect_sectors(codomain_fused_axes, domain_fused_axes)
164-
return FusionTensor(mat, (), (), tree_to_block_mapping)
172+
function fuse_axes(::Type{S}, ::Tuple{}) where {S<:AbstractSector}
173+
fused_axis = gradedrange([trivial(S) => 1])
174+
trees_to_ranges_mapping = Dict([SectorFusionTree{S}() => Block(1)[1:1]])
175+
return fused_axis, trees_to_ranges_mapping
176+
end
177+
function fuse_axes(::Type, outer_legs::Tuple{Vararg{AbstractGradedUnitRange}})
178+
fusion_trees_mult = fusion_trees_external_multiplicities(outer_legs)
179+
fused_leg, trees_to_ranges_mapping = compute_inner_ranges(fusion_trees_mult)
180+
return fused_leg, trees_to_ranges_mapping
181+
end
182+
183+
function fusion_trees_external_multiplicities(
184+
outer_legs::Tuple{Vararg{AbstractGradedUnitRange}}
185+
)
186+
tree_arrows = isdual.(outer_legs)
187+
return mapreduce(vcat, CartesianIndices(blocklength.(outer_legs))) do it
188+
block_sectors = map((g, i) -> blocklabels(g)[i], outer_legs, Tuple(it))
189+
block_mult = mapreduce((g, i) -> blocklengths(g)[i], *, outer_legs, Tuple(it); init=1)
190+
return build_trees(block_sectors, tree_arrows) .=> block_mult
191+
end
192+
end
193+
194+
function compute_inner_ranges(
195+
fusion_trees_mult::AbstractVector{<:Pair{<:SectorFusionTree,<:Integer}}
196+
)
197+
fused_leg = blockmergesort(
198+
gradedrange(root_sector.(first.(fusion_trees_mult)) .=> last.(fusion_trees_mult))
199+
)
200+
range_mapping = Dict{fieldtype(eltype(fusion_trees_mult), 1),typeof(Block(1)[1:1])}()
201+
fused_sectors = blocklabels(fused_leg)
202+
shifts = ones(Int, blocklength(fused_leg))
203+
for (f, m) in fusion_trees_mult
204+
s = root_sector(f)
205+
i = findfirst(==(s), fused_sectors)
206+
range_mapping[f] = Block(i)[shifts[i]:(shifts[i] + m - 1)]
207+
shifts[i] += m
208+
end
209+
return fused_leg, range_mapping
210+
end
211+
212+
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
213+
t = (b1, b2)
214+
return Block(Block.(t))[to_block_indices.(t)...]
215+
end
216+
217+
function intersect_sectors(
218+
codomain_trees_to_ranges_mapping::Dict{<:SectorFusionTree,<:BlockIndexRange{1}},
219+
domain_trees_to_ranges_mapping::Dict{<:SectorFusionTree,<:BlockIndexRange{1}},
220+
)
221+
return Dict(
222+
map(
223+
t -> first.(t) => to_blockindexrange(last.(t)...),
224+
Iterators.filter(
225+
t -> root_sector(first(t[1])) == root_sector(first(t[2])),
226+
Iterators.product(codomain_trees_to_ranges_mapping, domain_trees_to_ranges_mapping),
227+
),
228+
),
229+
)
165230
end
166231

167232
function initialize_data_matrix(
168-
elt::Type{<:Number}, codomain_fused_axes::FusedAxes, domain_fused_axes::FusedAxes
233+
elt::Type{<:Number},
234+
mat_row_axis::AbstractGradedUnitRange,
235+
nondual_col_axis::AbstractGradedUnitRange,
169236
)
170-
mat_row_axis = fused_axis(codomain_fused_axes)
171-
mat_col_axis = dual(fused_axis(domain_fused_axes))
172237
# non-abelian fusion trees have float eltype: need compatible type
173238
promoted = promote_type(elt, fusiontree_eltype(sector_type(mat_row_axis)))
174-
mat = BlockSparseArray{promoted}(mat_row_axis, mat_col_axis)
239+
mat = BlockSparseArray{promoted}(mat_row_axis, dual(nondual_col_axis))
175240
initialize_allowed_sectors!(mat)
176241
return mat
177242
end

0 commit comments

Comments
 (0)