Skip to content

Commit 6ecb67c

Browse files
committed
use GradedArrays functions
1 parent c345a71 commit 6ecb67c

File tree

5 files changed

+33
-49
lines changed

5 files changed

+33
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
2222
Accessors = "0.1.42"
2323
BlockArrays = "1.6"
2424
BlockSparseArrays = "0.7.4"
25-
GradedArrays = "0.4.13"
25+
GradedArrays = "0.4.14"
2626
HalfIntegers = "1.6"
2727
LRUCache = "1.6"
2828
LinearAlgebra = "1.10"

src/fusiontensor/base_interface.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using Accessors: @set
44
using BlockSparseArrays: @view!, eachstoredblock
5+
using GradedArrays: checkspaces, checkspaces_dual
56
using TensorAlgebra: BlockedTuple, tuplemortar
67

78
set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
@@ -11,21 +12,21 @@ Base.:*(ft::FusionTensor, x::Number) = x * ft
1112

1213
# tensor contraction is a block data_matrix product.
1314
function Base.:*(left::FusionTensor, right::FusionTensor)
14-
checkaxes_dual(domain_axes(left), codomain_axes(right))
15+
checkspaces_dual(domain_axes(left), codomain_axes(right))
1516
new_data_matrix = data_matrix(left) * data_matrix(right)
1617
return FusionTensor(new_data_matrix, codomain_axes(left), domain_axes(right))
1718
end
1819

1920
# tensor addition is a block data_matrix add.
2021
function Base.:+(left::FusionTensor, right::FusionTensor)
21-
checkaxes(axes(left), axes(right))
22+
checkspaces(axes(left), axes(right))
2223
return set_data_matrix(left, data_matrix(left) + data_matrix(right))
2324
end
2425

2526
Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft))
2627

2728
function Base.:-(left::FusionTensor, right::FusionTensor)
28-
checkaxes(axes(left), axes(right))
29+
checkspaces(axes(left), axes(right))
2930
return set_data_matrix(left, data_matrix(left) - data_matrix(right))
3031
end
3132

src/fusiontensor/fusiontensor.jl

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using GradedArrays:
99
SectorProduct,
1010
TrivialSector,
1111
dual,
12+
findfirstblock,
1213
flip,
1314
flip_dual,
1415
gradedrange,
@@ -27,23 +28,6 @@ using TypeParameterAccessors: type_parameters
2728

2829
# ======================================= Misc ===========================================
2930

30-
# TBD move to GradedArrays? rename findfirst_sector?
31-
function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange)
32-
return findfirst(==(s), sectors(flip_dual(g)))
33-
end
34-
35-
# TBD move to GradedArrays?
36-
function checkaxes(::Type{Bool}, axes1, axes2)
37-
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
38-
end
39-
40-
# TBD move to GradedArrays?
41-
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
42-
function checkaxes(ax1, ax2)
43-
return checkaxes(Bool, ax1, ax2) ||
44-
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
45-
end
46-
4731
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
4832
t = (b1, b2)
4933
return Block(Block.(t))[to_block_indices.(t)...]
@@ -260,9 +244,9 @@ end
260244
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
261245
# find outer block corresponding to fusion trees
262246
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
263-
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
264-
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
265-
return Block(b1..., b2...)
247+
b1 = findfirstblock.(codomain_axes(ft), leaves(f1))
248+
b2 = findfirstblock.(domain_axes(ft), leaves(f2))
249+
return Block(Int.(b1)..., Int.(b2)...)
266250
end
267251

268252
# ============================== GradedArrays interface ==================================

test/setup.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ using FusionTensors:
66
codomain_axes,
77
data_matrix,
88
domain_axes,
9-
checkaxes,
10-
checkaxes_dual,
119
domain_axis,
1210
codomain_axis,
1311
ndims_codomain,
1412
ndims_domain
15-
using GradedArrays: dual, sectors, sector_multiplicities, space_isequal
13+
using GradedArrays:
14+
checkspaces, checkspaces_dual, dual, sectors, sector_multiplicities, space_isequal
1615

1716
function check_sanity(ft::FusionTensor)
1817
nca = ndims_domain(ft)
@@ -25,8 +24,8 @@ function check_sanity(ft::FusionTensor)
2524
@assert nda + nca == ndims(ft) "invalid ndims"
2625

2726
@assert length(axes(ft)) == ndims(ft) "ndims does not match axes"
28-
checkaxes(axes(ft)[begin:nda], codomain_axes(ft))
29-
checkaxes(axes(ft)[(nda + 1):end], domain_axes(ft))
27+
checkspaces(axes(ft)[begin:nda], codomain_axes(ft))
28+
checkspaces(axes(ft)[(nda + 1):end], domain_axes(ft))
3029

3130
m = data_matrix(ft)
3231
@assert ndims(m) == 2 "invalid data_matrix ndims"

test/test_basics.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ using FusionTensors:
99
data_matrix,
1010
domain_axes,
1111
FusionTensor,
12-
checkaxes,
13-
checkaxes_dual,
1412
codomain_axis,
1513
domain_axis,
1614
ndims_domain,
@@ -21,6 +19,8 @@ using GradedArrays:
2119
SectorProduct,
2220
TrivialSector,
2321
Z,
22+
checkspaces,
23+
checkspaces_dual,
2424
dual,
2525
flip,
2626
gradedrange,
@@ -54,11 +54,11 @@ include("setup.jl")
5454

5555
# getters
5656
@test data_matrix(ft1) == m
57-
@test checkaxes(axes(ft1), tuplemortar(((g1,), (g2,))))
57+
@test checkspaces(axes(ft1), tuplemortar(((g1,), (g2,))))
5858

5959
# misc
60-
@test checkaxes(codomain_axes(ft1), (g1,))
61-
@test checkaxes(domain_axes(ft1), (g2,))
60+
@test checkspaces(codomain_axes(ft1), (g1,))
61+
@test checkspaces(domain_axes(ft1), (g2,))
6262
@test ndims_codomain(ft1) == 1
6363
@test ndims_domain(ft1) == 1
6464
@test size(data_matrix(ft1)) == (6, 5)
@@ -86,42 +86,42 @@ include("setup.jl")
8686
@test ft2 !== ft1
8787
@test data_matrix(ft2) == data_matrix(ft1)
8888
@test data_matrix(ft2) !== data_matrix(ft1)
89-
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
90-
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
89+
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
90+
@test checkspaces(domain_axes(ft2), domain_axes(ft1))
9191

9292
ft2 = deepcopy(ft1)
9393
@test ft2 !== ft1
9494
@test data_matrix(ft2) == data_matrix(ft1)
9595
@test data_matrix(ft2) !== data_matrix(ft1)
96-
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
97-
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
96+
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
97+
@test checkspaces(domain_axes(ft2), domain_axes(ft1))
9898

9999
# similar
100100
ft2 = similar(ft1)
101101
@test isnothing(check_sanity(ft2))
102102
@test eltype(ft2) == Float64
103-
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
104-
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
103+
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
104+
@test checkspaces(domain_axes(ft2), domain_axes(ft1))
105105

106106
ft3 = similar(ft1, ComplexF64)
107107
@test isnothing(check_sanity(ft3))
108108
@test eltype(ft3) == ComplexF64
109-
@test checkaxes(codomain_axes(ft3), codomain_axes(ft1))
110-
@test checkaxes(domain_axes(ft3), domain_axes(ft1))
109+
@test checkspaces(codomain_axes(ft3), codomain_axes(ft1))
110+
@test checkspaces(domain_axes(ft3), domain_axes(ft1))
111111

112112
@test_throws AssertionError similar(ft1, Int)
113113

114114
ft5 = similar(ft1, ComplexF32, ((g1, g1), (g2,)))
115115
@test isnothing(check_sanity(ft5))
116116
@test eltype(ft5) == ComplexF64
117-
@test checkaxes(codomain_axes(ft5), (g1, g1))
118-
@test checkaxes(domain_axes(ft5), (g2,))
117+
@test checkspaces(codomain_axes(ft5), (g1, g1))
118+
@test checkspaces(domain_axes(ft5), (g2,))
119119

120120
ft5 = similar(ft1, ComplexF32, tuplemortar(((g1, g1), (g2,))))
121121
@test isnothing(check_sanity(ft5))
122122
@test eltype(ft5) == ComplexF64
123-
@test checkaxes(codomain_axes(ft5), (g1, g1))
124-
@test checkaxes(domain_axes(ft5), (g2,))
123+
@test checkspaces(codomain_axes(ft5), (g1, g1))
124+
@test checkspaces(domain_axes(ft5), (g2,))
125125
end
126126

127127
@testset "More than 2 axes" begin
@@ -135,8 +135,8 @@ end
135135
ft = FusionTensor(m2, (g1, g2), (g3, g4))
136136

137137
@test data_matrix(ft) == m2
138-
@test checkaxes(codomain_axes(ft), (g1, g2))
139-
@test checkaxes(domain_axes(ft), (g3, g4))
138+
@test checkspaces(codomain_axes(ft), (g1, g2))
139+
@test checkspaces(domain_axes(ft), (g3, g4))
140140

141141
@test axes(ft) == FusionTensorAxes(tuplemortar(((g1, g2), (g3, g4))))
142142
@test ndims_codomain(ft) == 2
@@ -271,7 +271,7 @@ end
271271
ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
272272
@test_throws DimensionMismatch ft7 + ft3
273273
@test_throws DimensionMismatch ft7 - ft3
274-
@test_throws DimensionMismatch ft7 * ft3
274+
@test_throws ArgumentError ft7 * ft3
275275
end
276276

277277
@testset "specific constructors" begin

0 commit comments

Comments
 (0)