Skip to content

Commit 0244131

Browse files
authored
define FusionTensorAxes (#58)
1 parent 72e6907 commit 0244131

File tree

7 files changed

+373
-149
lines changed

7 files changed

+373
-149
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -21,7 +21,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
2121
Accessors = "0.1.42"
2222
BlockArrays = "1.6"
2323
BlockSparseArrays = "0.7.4"
24-
GradedArrays = "0.4.7"
24+
GradedArrays = "0.4.13"
2525
HalfIntegers = "1.6"
2626
LRUCache = "1.6"
2727
LinearAlgebra = "1.10"

src/FusionTensors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module FusionTensors
22

33
include("fusion_trees/fusiontree.jl")
44
include("fusion_trees/clebsch_gordan_tensors.jl")
5+
include("fusiontensor/fusiontensoraxes.jl")
56
include("fusiontensor/fusiontensor.jl")
67
include("fusiontensor/base_interface.jl")
78
include("fusiontensor/array_cast.jl")

src/fusiontensor/base_interface.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function transpose_mapping(b::BlockIndexRange{2})
4444
return new_block[reverse(b.indices)...]
4545
end
4646
function Base.adjoint(ft::FusionTensor)
47-
new_axes = tuplemortar((dual.(domain_axes(ft)), dual.(codomain_axes(ft))))
47+
new_axes = FusionTensorAxes(dual.(domain_axes(ft)), dual.(codomain_axes(ft)))
4848
return FusionTensor(
4949
adjoint(data_matrix(ft)), new_axes, transpose_mapping(trees_block_mapping(ft))
5050
)
@@ -57,7 +57,7 @@ Base.conj(ft::FusionTensor{<:Real}) = ft # same object for real element type
5757
Base.conj(ft::FusionTensor) = set_data_matrix(ft, conj(data_matrix(ft)))
5858

5959
function Base.copy(ft::FusionTensor)
60-
return FusionTensor(copy(data_matrix(ft)), copy.(axes(ft)), copy(trees_block_mapping(ft)))
60+
return FusionTensor(copy(data_matrix(ft)), copy(axes(ft)), copy(trees_block_mapping(ft)))
6161
end
6262

6363
function Base.deepcopy(ft::FusionTensor)
@@ -66,8 +66,7 @@ function Base.deepcopy(ft::FusionTensor)
6666
)
6767
end
6868

69-
# eachindex is automatically defined for AbstractArray. We do not want it.
70-
Base.eachindex(::FusionTensor) = error("eachindex not defined for FusionTensor")
69+
Base.eachindex(::FusionTensor) = throw(MethodError(eachindex, (FusionTensor,)))
7170

7271
function Base.getindex(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
7372
charge_matrix = data_matrix(ft)[trees_block_mapping(ft)[f1, f2]]

src/fusiontensor/fusiontensor.jl

Lines changed: 121 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -19,136 +19,32 @@ using GradedArrays:
1919
sectormergesort,
2020
sectors,
2121
space_isequal
22-
using TensorAlgebra: BlockedTuple, tuplemortar
22+
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
2323
using TensorProducts: tensor_product
2424
using TypeParameterAccessors: type_parameters
2525

26-
struct FusionTensor{T,N,Axes,Mat<:AbstractMatrix{T},Mapping} <: AbstractArray{T,N}
27-
data_matrix::Mat
28-
axes::Axes
29-
trees_block_mapping::Mapping
30-
31-
# inner constructor to impose constraints on types
32-
function FusionTensor{T,N,Axes,Mat,Mapping}(
33-
mat, legs, trees_block_mapping
34-
) where {T,N,Axes,Mat,Mapping}
35-
S = isempty(legs) ? TrivialSector : sector_type(first(legs))
36-
@assert keytype(trees_block_mapping) <:
37-
Tuple{<:SectorFusionTree{S},<:SectorFusionTree{S}}
38-
@assert all(sector_type.(Tuple(legs)) .=== S)
39-
return new{T,N,Axes,Mat,Mapping}(mat, legs, trees_block_mapping)
40-
end
41-
end
42-
43-
function FusionTensor(
44-
mat::AbstractMatrix,
45-
legs::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractGradedUnitRange}}},
46-
trees_block_mapping::Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}},
47-
)
48-
return FusionTensor{
49-
eltype(mat),length(legs),typeof(legs),typeof(mat),typeof(trees_block_mapping)
50-
}(
51-
mat, legs, trees_block_mapping
52-
)
53-
end
54-
55-
# getters
56-
data_matrix(ft::FusionTensor) = ft.data_matrix
57-
trees_block_mapping(ft::FusionTensor) = ft.trees_block_mapping
58-
59-
# misc access
60-
codomain_axes(ft::FusionTensor) = axes(ft)[Block(1)]
61-
domain_axes(ft::FusionTensor) = axes(ft)[Block(2)]
62-
ndims_codomain(ft::FusionTensor) = length(codomain_axes(ft))
63-
ndims_domain(ft::FusionTensor) = length(domain_axes(ft))
64-
65-
dummy_axis(ft::FusionTensor) = dummy_axis(sector_type(ft))
66-
dummy_axis(::Type{S}) where {S<:AbstractSector} = gradedrange([trivial(S) => 1])
67-
68-
function codomain_axis(ft::FusionTensor)
69-
if ndims_codomain(ft) == 0
70-
return dummy_axis(ft)
71-
end
72-
return (codomain_axes(ft)...)
73-
end
74-
function domain_axis(ft::FusionTensor)
75-
if ndims_domain(ft) == 0
76-
return dummy_axis(ft)
77-
end
78-
return dual((dual.(domain_axes(ft))...))
79-
end
80-
function charge_block_size(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
81-
b = Tuple(findblock(ft, f1, f2))
82-
return ntuple(i -> sector_multiplicity(axes(ft, i)[b[i]]), ndims(ft))
83-
end
26+
# ======================================= Misc ===========================================
8427

85-
# GradedArrays interface
86-
function GradedArrays.sector_type(
87-
::Type{<:FusionTensor{<:Any,<:Any,<:Any,<:Any,<:Dict{<:Tuple{<:Any,F}}}}
88-
) where {F}
89-
return sector_type(F)
90-
end
91-
92-
# BlockArrays interface
93-
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
94-
# find outer block corresponding to fusion trees
95-
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
96-
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
97-
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
98-
return Block(b1..., b2...)
99-
end
10028
# TBD move to GradedArrays? rename findfirst_sector?
10129
function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange)
10230
return findfirst(==(s), sectors(flip_dual(g)))
10331
end
10432

105-
function sanitize_axes(raw_legs::Tuple{Vararg{AbstractGradedUnitRange}})
106-
legs = promote_sectors(raw_legs)
107-
@assert all(check_unique_sectors.(legs))
108-
return legs
109-
end
110-
sanitize_axes(legs::BlockedTuple{2,(0, 0)}) = TrivialSector, legs
111-
function sanitize_axes(raw_legs::BlockedTuple{2})
112-
flat_legs = sanitize_axes(Tuple(raw_legs))
113-
return sector_type(first(flat_legs)), BlockedTuple(flat_legs, Val(blocklengths(raw_legs)))
114-
end
115-
116-
function check_unique_sectors(g::AbstractGradedUnitRange)
117-
return length(unique(sectors(g))) == blocklength(g)
118-
end
119-
120-
promote_sectors(legs::NTuple{<:Any,<:AbstractGradedUnitRange}) = legs # nothing to do
121-
function promote_sectors(legs)
122-
T = promote_sector_type(legs)
123-
# fuse with trivial to insert all missing arguments inside each GradedAxis
124-
# avoid depending on GradedArrays internals
125-
s0 = trivial(T)
126-
return map_sectors.(s -> only(sectors(to_gradedrange(tensor_product(s0, s)))), legs)
127-
end
128-
129-
function promote_sector_type(legs)
130-
# fuse trivial sectors to produce unified type
131-
# avoid depending on GradedArrays internals
132-
return sector_type(tensor_product(trivial.(legs)...))
33+
# TBD move to GradedArrays?
34+
function checkaxes(::Type{Bool}, axes1, axes2)
35+
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
13336
end
13437

135-
# initialize with already computed data_matrix
136-
function FusionTensor(
137-
x,
138-
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
139-
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
140-
)
141-
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
38+
# TBD move to GradedArrays?
39+
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
40+
function checkaxes(ax1, ax2)
41+
return checkaxes(Bool, ax1, ax2) ||
42+
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
14243
end
14344

144-
function FusionTensor(mat::AbstractMatrix, legs::BlockedTuple{2})
145-
# init with empty data_matrix to construct trees_block_mapping
146-
ft = FusionTensor(eltype(mat), legs)
147-
for b in eachblockstoredindex(mat)
148-
@assert b in eachblockstoredindex(data_matrix(ft)) # check matrix block is allowed
149-
data_matrix(ft)[b] = mat[b]
150-
end
151-
return ft
45+
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
46+
t = (b1, b2)
47+
return Block(Block.(t))[to_block_indices.(t)...]
15248
end
15349

15450
function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
@@ -159,21 +55,8 @@ function flip_domain(nonflipped_col_axis, nonflipped_trees_to_ranges)
15955
return col_axis, domain_trees_to_ranges_mapping
16056
end
16157

162-
# empty matrix
163-
function FusionTensor(elt::Type, raw_legs::BlockedTuple{2})
164-
S, legs = sanitize_axes(raw_legs)
165-
row_axis, codomain_trees_to_ranges = fuse_axes(S, legs[Block(1)])
166-
col_axis, domain_trees_to_ranges = flip_domain(fuse_axes(S, dual.(legs[Block(2)]))...)
167-
168-
mat = initialize_data_matrix(elt, row_axis, col_axis)
169-
tree_to_block_mapping = intersect_codomain_domain(
170-
codomain_trees_to_ranges, domain_trees_to_ranges
171-
)
172-
return FusionTensor(mat, legs, tree_to_block_mapping)
173-
end
174-
17558
function fuse_axes(::Type{S}, ::Tuple{}) where {S<:AbstractSector}
176-
fused_axis = gradedrange([trivial(S) => 1])
59+
fused_axis = trivial_axis(S)
17760
trees_to_ranges_mapping = Dict([SectorFusionTree{S}() => Block(1)[1:1]])
17861
return fused_axis, trees_to_ranges_mapping
17962
end
@@ -213,11 +96,6 @@ function compute_inner_ranges(fusion_trees_mult)
21396
return fused_leg, range_mapping
21497
end
21598

216-
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
217-
t = (b1, b2)
218-
return Block(Block.(t))[to_block_indices.(t)...]
219-
end
220-
22199
function intersect_codomain_domain(
222100
codomain_trees_to_ranges_mapping::Dict{<:SectorFusionTree,<:BlockIndexRange{1}},
223101
domain_trees_to_ranges_mapping::Dict{<:SectorFusionTree,<:BlockIndexRange{1}},
@@ -257,11 +135,112 @@ function initialize_data_matrix(
257135
return mat
258136
end
259137

260-
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
261-
function checkaxes(ax1, ax2)
262-
return checkaxes(Bool, ax1, ax2) ||
263-
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
138+
# ==================================== Definitions =======================================
139+
140+
struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} <:
141+
AbstractArray{T,N}
142+
data_matrix::Mat
143+
axes::Axes
144+
trees_block_mapping::Mapping
145+
146+
# inner constructor to impose constraints on types
147+
function FusionTensor{T,N,Axes,Mat,Mapping}(
148+
mat, legs, trees_block_mapping
149+
) where {T,N,Axes,Mat,Mapping}
150+
S = sector_type(legs)
151+
@assert keytype(trees_block_mapping) <:
152+
Tuple{<:SectorFusionTree{S},<:SectorFusionTree{S}}
153+
return new{T,N,Axes,Mat,Mapping}(mat, legs, trees_block_mapping)
154+
end
264155
end
265-
function checkaxes(::Type{Bool}, axes1, axes2)
266-
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
156+
157+
# ===================================== Accessors ========================================
158+
159+
data_matrix(ft::FusionTensor) = ft.data_matrix
160+
trees_block_mapping(ft::FusionTensor) = ft.trees_block_mapping
161+
162+
# ==================================== Constructors ======================================
163+
164+
function FusionTensor(
165+
mat::AbstractMatrix,
166+
legs::FusionTensorAxes,
167+
trees_block_mapping::Dict{<:Tuple{<:SectorFusionTree,<:SectorFusionTree}},
168+
)
169+
return FusionTensor{
170+
eltype(mat),length(legs),typeof(legs),typeof(mat),typeof(trees_block_mapping)
171+
}(
172+
mat, legs, trees_block_mapping
173+
)
174+
end
175+
176+
# empty matrix
177+
function FusionTensor(elt::Type, legs::FusionTensorAxes)
178+
S = sector_type(legs)
179+
row_axis, codomain_trees_to_ranges = fuse_axes(S, codomain(legs))
180+
col_axis, domain_trees_to_ranges = flip_domain(fuse_axes(S, dual.(domain(legs)))...)
181+
182+
mat = initialize_data_matrix(elt, row_axis, col_axis)
183+
tree_to_block_mapping = intersect_codomain_domain(
184+
codomain_trees_to_ranges, domain_trees_to_ranges
185+
)
186+
return FusionTensor(mat, legs, tree_to_block_mapping)
187+
end
188+
189+
#constructor from precomputed data_matrix
190+
function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
191+
# init with empty data_matrix to construct trees_block_mapping
192+
ft = FusionTensor(eltype(mat), legs)
193+
for b in eachblockstoredindex(mat)
194+
b in eachblockstoredindex(data_matrix(ft)) ||
195+
throw(ArgumentError("matrix block $b is not allowed"))
196+
data_matrix(ft)[b] = mat[b]
197+
end
198+
return ft
199+
end
200+
201+
FusionTensor(x, legs::BlockedTuple{2}) = FusionTensor(x, FusionTensorAxes(legs))
202+
203+
# constructor from split axes
204+
function FusionTensor(
205+
x,
206+
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
207+
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
208+
)
209+
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
210+
end
211+
212+
# ================================ BlockArrays interface =================================
213+
214+
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
215+
# find outer block corresponding to fusion trees
216+
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
217+
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
218+
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
219+
return Block(b1..., b2...)
220+
end
221+
222+
# ============================== GradedArrays interface ==================================
223+
224+
function GradedArrays.sector_type(::Type{FT}) where {FT<:FusionTensor}
225+
return sector_type(type_parameters(FT, 3))
226+
end
227+
228+
# ============================== FusionTensor interface ==================================
229+
230+
# misc access
231+
codomain_axes(ft::FusionTensor) = codomain(axes(ft))
232+
233+
domain_axes(ft::FusionTensor) = domain(axes(ft))
234+
235+
codomain_axis(ft::FusionTensor) = fused_codomain(axes(ft))
236+
237+
domain_axis(ft::FusionTensor) = fused_domain(axes(ft))
238+
239+
ndims_codomain(ft::FusionTensor) = length_codomain(axes(ft))
240+
241+
ndims_domain(ft::FusionTensor) = length_domain(axes(ft))
242+
243+
function charge_block_size(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
244+
b = Tuple(findblock(ft, f1, f2))
245+
return ntuple(i -> sector_multiplicity(axes(ft, i)[b[i]]), ndims(ft))
267246
end

0 commit comments

Comments
 (0)