Skip to content

Commit f6ae180

Browse files
committed
use BlockedTuple
1 parent d3b624e commit f6ae180

File tree

7 files changed

+99
-100
lines changed

7 files changed

+99
-100
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ LinearAlgebra = "1.10.0"
3333
SparseArraysBase = "0.2.0"
3434
Strided = "2.2.0"
3535
SymmetrySectors = "0.1.1"
36-
TensorAlgebra = "0.1.0"
36+
TensorAlgebra = "0.1.5"
3737
TypeParameterAccessors = "0.2.0"
3838
WignerSymbols = "2.0.0"
3939
julia = "1.10"

src/fusiontensor/base_interface.jl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Accessors: @set
44

55
using BlockSparseArrays: @view!
6+
using TensorAlgebra: BlockedTuple, tuplemortar
67

78
set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
89

@@ -43,35 +44,25 @@ function transpose_mapping(b::BlockIndexRange{2})
4344
return new_block[reverse(b.indices)...]
4445
end
4546
function Base.adjoint(ft::FusionTensor)
47+
new_axes = tuplemortar((dual.(domain_axes(ft)), dual.(codomain_axes(ft))))
4648
return FusionTensor(
47-
adjoint(data_matrix(ft)),
48-
dual.(domain_axes(ft)),
49-
dual.(codomain_axes(ft)),
50-
transpose_mapping(trees_block_mapping(ft)),
49+
adjoint(data_matrix(ft)), new_axes, transpose_mapping(trees_block_mapping(ft))
5150
)
5251
end
5352

54-
Base.axes(ft::FusionTensor) = (codomain_axes(ft)..., domain_axes(ft)...)
53+
Base.axes(ft::FusionTensor) = ft.axes
5554

5655
# conj is defined as coefficient wise complex conjugation, without axis dual
5756
Base.conj(ft::FusionTensor{<:Real}) = ft # same object for real element type
5857
Base.conj(ft::FusionTensor) = set_data_matrix(ft, conj(data_matrix(ft)))
5958

6059
function Base.copy(ft::FusionTensor)
61-
return FusionTensor(
62-
copy(data_matrix(ft)),
63-
copy.(codomain_axes(ft)),
64-
copy.(domain_axes(ft)),
65-
copy(trees_block_mapping(ft)),
66-
)
60+
return FusionTensor(copy(data_matrix(ft)), copy.(axes(ft)), copy(trees_block_mapping(ft)))
6761
end
6862

6963
function Base.deepcopy(ft::FusionTensor)
7064
return FusionTensor(
71-
deepcopy(data_matrix(ft)),
72-
deepcopy.(codomain_axes(ft)),
73-
deepcopy.(domain_axes(ft)),
74-
deepcopy(trees_block_mapping(ft)),
65+
deepcopy(data_matrix(ft)), deepcopy(axes(ft)), deepcopy(trees_block_mapping(ft))
7566
)
7667
end
7768

@@ -91,16 +82,33 @@ end
9182

9283
Base.permutedims(ft::FusionTensor, args...) = fusiontensor_permutedims(ft, args...)
9384

94-
function Base.similar(ft::FusionTensor, ::Type{T}) where {T}
85+
Base.similar(ft::FusionTensor) = similar(ft, eltype(ft))
86+
function Base.similar(ft::FusionTensor, T::Type)
87+
# reuse trees_block_mapping
88+
9589
# some fusion trees have Float64 eltype: need compatible type
9690
@assert promote_type(T, fusiontree_eltype(sector_type(ft))) === T
9791
mat = similar(data_matrix(ft), T)
9892
initialize_allowed_sectors!(mat)
99-
return FusionTensor(mat, codomain_axes(ft), domain_axes(ft), trees_block_mapping(ft))
93+
return FusionTensor(mat, axes(ft), trees_block_mapping(ft))
94+
end
95+
96+
# trigger explicit error in TensorAlgebra.contract
97+
# TBD impose some convention? Remove?
98+
function Base.similar(
99+
ft::FusionTensor, T::Type, new_axes::Tuple{Vararg{AbstractGradedUnitRange}}
100+
)
101+
throw(DimensionMismatch("Need bituple of axes"))
102+
end
103+
function Base.similar(ft::FusionTensor, T::Type, new_axes::Tuple{})
104+
throw(DimensionMismatch("Need bituple of axes"))
100105
end
101106

102-
function Base.similar(::FusionTensor, ::Type{T}, new_axes::Tuple{<:Tuple,<:Tuple}) where {T}
103-
return FusionTensor(T, new_axes[1], new_axes[2])
107+
function Base.similar(ft::FusionTensor, T::Type, new_axes::Tuple{<:Tuple,<:Tuple})
108+
return similar(ft, T, tuplemortar(new_axes))
109+
end
110+
function Base.similar(::FusionTensor, T::Type, new_axes::BlockedTuple{2})
111+
return FusionTensor(T, new_axes)
104112
end
105113

106114
Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor")

src/fusiontensor/fusiontensor.jl

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,39 @@ using GradedUnitRanges:
1515
sector_type,
1616
space_isequal
1717
using SymmetrySectors: SectorProduct, TrivialSector
18+
using TensorAlgebra: BlockedTuple, tuplemortar
1819

19-
struct FusionTensor{T,N,CoDomainAxes,DomainAxes,Mat,Mapping} <: AbstractArray{T,N}
20+
struct FusionTensor{T,N,Axes,Mat,Mapping} <: AbstractArray{T,N}
2021
data_matrix::Mat
21-
codomain_axes::CoDomainAxes
22-
domain_axes::DomainAxes
22+
axes::Axes
2323
trees_block_mapping::Mapping
2424

2525
# inner constructor to impose constraints on types
2626
function FusionTensor(
2727
mat::AbstractMatrix,
28-
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
29-
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
28+
legs::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractGradedUnitRange}}},
3029
trees_block_mapping::Dict,
3130
)
3231
S = sector_type(axes(mat, 1))
3332
@assert sector_type(axes(mat, 2)) === S
3433
@assert keytype(trees_block_mapping) <:
3534
Tuple{<:SectorFusionTree{S},<:SectorFusionTree{S}}
36-
@assert all(sector_type.(codomain_legs) .=== S)
37-
@assert all(sector_type.(domain_legs) .=== S)
35+
@assert all(sector_type.(Tuple(legs)) .=== S)
3836
return new{
39-
eltype(mat),
40-
length(codomain_legs) + length(domain_legs),
41-
typeof(codomain_legs),
42-
typeof(domain_legs),
43-
typeof(mat),
44-
typeof(trees_block_mapping),
37+
eltype(mat),length(legs),typeof(legs),typeof(mat),typeof(trees_block_mapping)
4538
}(
46-
mat, codomain_legs, domain_legs, trees_block_mapping
39+
mat, legs, trees_block_mapping
4740
)
4841
end
4942
end
5043

5144
# getters
5245
data_matrix(ft::FusionTensor) = ft.data_matrix
53-
codomain_axes(ft::FusionTensor) = ft.codomain_axes
54-
domain_axes(ft::FusionTensor) = ft.domain_axes
5546
trees_block_mapping(ft::FusionTensor) = ft.trees_block_mapping
5647

5748
# misc access
49+
codomain_axes(ft::FusionTensor) = first(blocks(axes(ft)))
50+
domain_axes(ft::FusionTensor) = last(blocks(axes(ft)))
5851
ndims_codomain(ft::FusionTensor) = length(codomain_axes(ft))
5952
ndims_domain(ft::FusionTensor) = length(domain_axes(ft))
6053

@@ -68,7 +61,7 @@ end
6861

6962
# GradedUnitRanges interface
7063
function GradedUnitRanges.sector_type(
71-
::Type{<:FusionTensor{<:Any,<:Any,<:Any,<:Any,<:Any,<:Dict{<:Tuple{<:Any,F}}}}
64+
::Type{<:FusionTensor{<:Any,<:Any,<:Any,<:Any,<:Dict{<:Tuple{<:Any,F}}}}
7265
) where {F}
7366
return sector_type(F)
7467
end
@@ -91,16 +84,10 @@ function sanitize_axes(raw_legs::Tuple{Vararg{AbstractGradedUnitRange}})
9184
@assert all(check_unique_blocklabels.(legs))
9285
return legs
9386
end
94-
sanitize_axes(::Tuple{}, ::Tuple{}) = TrivialSector, (), ()
95-
function sanitize_axes(
96-
codomain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
97-
domain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
98-
)
99-
legs = sanitize_axes((codomain_legs_raw..., domain_legs_raw...))
100-
S = sector_type(first(legs))
101-
codomain_legs = legs[begin:length(codomain_legs_raw)]
102-
domain_legs = legs[(length(codomain_legs_raw) + 1):end]
103-
return S, domain_legs, codomain_legs
87+
sanitize_axes(legs::BlockedTuple{2,(0, 0)}) = TrivialSector, legs
88+
function sanitize_axes(raw_legs::BlockedTuple{2})
89+
flat_legs = sanitize_axes(Tuple(raw_legs))
90+
return sector_type(first(flat_legs)), BlockedTuple(flat_legs, Val(blocklengths(raw_legs)))
10491
end
10592

10693
function check_unique_blocklabels(g::AbstractGradedUnitRange)
@@ -131,12 +118,16 @@ end
131118

132119
# initialize with already computed data_matrix
133120
function FusionTensor(
134-
mat::AbstractMatrix,
121+
x,
135122
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
136123
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
137124
)
125+
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
126+
end
127+
128+
function FusionTensor(mat::AbstractMatrix, legs::BlockedTuple{2})
138129
# init with empty data_matrix to construct trees_block_mapping
139-
ft = FusionTensor(eltype(mat), codomain_legs, domain_legs)
130+
ft = FusionTensor(eltype(mat), legs)
140131
@assert space_isequal(matrix_row_axis(ft), axes(mat, 1))
141132
@assert space_isequal(matrix_column_axis(ft), axes(mat, 2))
142133
for b in eachblockstoredindex(mat)
@@ -147,21 +138,17 @@ function FusionTensor(
147138
end
148139

149140
# empty matrix
150-
function FusionTensor(
151-
elt::Type,
152-
codomain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
153-
domain_legs_raw::Tuple{Vararg{AbstractGradedUnitRange}},
154-
)
155-
S, domain_legs, codomain_legs = sanitize_axes(codomain_legs_raw, domain_legs_raw)
141+
function FusionTensor(elt::Type, raw_legs::BlockedTuple{2})
142+
S, legs = sanitize_axes(raw_legs)
156143

157-
row_axis, codomain_trees_to_ranges_mapping = fuse_axes(S, codomain_legs)
158-
nondual_col_axis, domain_trees_to_ranges_mapping = fuse_axes(S, dual.(domain_legs))
144+
row_axis, codomain_trees_to_ranges_mapping = fuse_axes(S, first(blocks(legs)))
145+
nondual_col_axis, domain_trees_to_ranges_mapping = fuse_axes(S, dual.(last(blocks(legs))))
159146

160147
mat = initialize_data_matrix(elt, row_axis, nondual_col_axis)
161148
tree_to_block_mapping = intersect_codomain_domain(
162149
codomain_trees_to_ranges_mapping, domain_trees_to_ranges_mapping
163150
)
164-
return FusionTensor(mat, codomain_legs, domain_legs, tree_to_block_mapping)
151+
return FusionTensor(mat, legs, tree_to_block_mapping)
165152
end
166153

167154
function fuse_axes(::Type{S}, ::Tuple{}) where {S<:AbstractSector}
@@ -178,14 +165,19 @@ end
178165
function fusion_trees_external_multiplicities(
179166
outer_legs::Tuple{Vararg{AbstractGradedUnitRange}}
180167
)
181-
tree_arrows = isdual.(outer_legs)
182168
return mapreduce(vcat, CartesianIndices(blocklength.(outer_legs))) do it
183-
block_sectors = map((g, i) -> blocklabels(g)[i], outer_legs, Tuple(it))
184-
block_mult = mapreduce((g, i) -> blocklengths(g)[i], *, outer_legs, Tuple(it); init=1)
185-
return build_trees(block_sectors, tree_arrows) .=> block_mult
169+
return fusion_trees_external_multiplicities(outer_legs, Tuple(it))
186170
end
187171
end
188172

173+
function fusion_trees_external_multiplicities(
174+
outer_legs::NTuple{N,AbstractGradedUnitRange}, indices::NTuple{N,Int}
175+
) where {N}
176+
block_sectors = map((g, i) -> blocklabels(g)[i], outer_legs, indices)
177+
block_mult = mapreduce((g, i) -> blocklengths(g)[i], *, outer_legs, indices; init=1)
178+
return build_trees(block_sectors, isdual.(outer_legs)) .=> block_mult
179+
end
180+
189181
function compute_inner_ranges(
190182
fusion_trees_mult::AbstractVector{<:Pair{<:SectorFusionTree,<:Integer}}
191183
)

src/fusiontensor/tensor_algebra_interface.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,40 @@ using BlockArrays: Block
66

77
using TensorAlgebra: BlockedPermutation, Matricize, TensorAlgebra
88

9-
# TBD how to deal with inner contraction = no ouput axis?
9+
# TODO how to deal with inner contraction = no ouput axis?
10+
# => currently biperm_dest is a BlockedPermutation{0}, change this
1011
function TensorAlgebra.allocate_output(
1112
::typeof(contract),
1213
biperm_dest::BlockedPermutation{2},
13-
a1::FusionTensor{T1,N},
14-
biperm1::BlockedPermutation{2,N},
15-
a2::FusionTensor{T2,M},
16-
biperm2::BlockedPermutation{2,M},
14+
a1::FusionTensor,
15+
biperm1::BlockedPermutation{2},
16+
a2::FusionTensor,
17+
biperm2::BlockedPermutation{2},
1718
α::Number=true,
18-
) where {T1,T2,N,M}
19+
)
1920
axes_dest = (
20-
(i -> axes(a1)[i]).(biperm1[Block(1)]), (i -> axes(a2)[i]).(biperm2[Block(2)])
21+
map(i -> axes(a1)[i], first(blocks(biperm1))),
22+
map(i -> axes(a2)[i], last(blocks(biperm2))),
2123
)
2224
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
2325
end
2426

25-
# TBD do really I need to defined these as I cannot use them in contract! and has to redefine it?
26-
# TensorAlgebra.fusedims(ft::FusionTensor, perm::BlockedPermutation) = permutedims(ft, perm)
27-
# function TensorAlgebra.splitdims(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
28-
# function TensorAlgebra.splitdims!(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
27+
# TBD do really I need to define these as I cannot use them in contract! and has to redefine it?
28+
#TensorAlgebra.fusedims(ft::FusionTensor, perm::BlockedPermutation{2}) = permutedims(ft, perm)
29+
#function TensorAlgebra.splitdims(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
30+
#function TensorAlgebra.splitdims!(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
2931

3032
# I cannot use contract! from TensorAlgebra/src/contract/contract_matricize/contract.jl
3133
# as it calls _mul!, which I should not overload.
32-
# TBD I can also overload higher up and do not allow use of different algorithms
34+
# TBD define fallback _mul!(::AbstractArray, ::AbstractArray, ::AbstractArray) in TensorAlgebra?
3335
function TensorAlgebra.contract!(
34-
alg::Matricize,
36+
::Matricize,
3537
a_dest::FusionTensor,
36-
biperm_dest::BlockedPermutation,
38+
::BlockedPermutation{2},
3739
a1::FusionTensor,
38-
biperm1::BlockedPermutation,
40+
biperm1::BlockedPermutation{2},
3941
a2::FusionTensor,
40-
biperm2::BlockedPermutation,
42+
biperm2::BlockedPermutation{2},
4143
α::Number,
4244
β::Number,
4345
)

test/basics/test_array_cast.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,13 @@ end
201201
if VERSION < v"1.11"
202202
@test_broken to_fusiontensor(zerodim, (), ()) isa FusionTensor # https://github.com/JuliaLang/julia/issues/52615
203203
else
204-
# TODO fix: add specialized method, maybe fix TensorAlgebra
205-
@test_broken to_fusiontensor(zerodim, (), ())
206-
#@test ft isa FusionTensor
207-
#@test ndims(ft) == 0
208-
#@test isnothing(check_sanity(ft))
209-
#@test size(data_matrix(ft)) == (1, 1)
210-
#@test data_matrix(ft)[1, 1] ≈ 1.0
211-
#@test_broken Array(ft) ≈ zerodim # cannot create zerodim BlockSparseArray
204+
ft = to_fusiontensor(zerodim, (), ())
205+
@test ft isa FusionTensor
206+
@test ndims(ft) == 0
207+
@test isnothing(check_sanity(ft))
208+
@test size(data_matrix(ft)) == (1, 1)
209+
@test data_matrix(ft)[1, 1] 1.0
210+
@test_broken Array(ft) zerodim # https://github.com/ITensor/BlockSparseArrays.jl/issues/27
212211
end
213212
end
214213
end

0 commit comments

Comments
 (0)