Skip to content

Commit 3e9ee17

Browse files
authored
Adapt to Gradedarrays (#61)
1 parent 1d4ef1a commit 3e9ee17

File tree

9 files changed

+64
-61
lines changed

9 files changed

+64
-61
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
2222
Accessors = "0.1.42"
2323
BlockArrays = "1.6"
2424
BlockSparseArrays = "0.7.4"
25-
GradedArrays = "0.4.13"
25+
GradedArrays = "0.4.14"
2626
HalfIntegers = "1.6"
2727
LRUCache = "1.6"
2828
LinearAlgebra = "1.10"

src/fusiontensor/base_interface.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using Accessors: @set
44
using BlockSparseArrays: @view!, eachstoredblock
5+
using GradedArrays: checkspaces, checkspaces_dual
56
using TensorAlgebra: BlockedTuple, tuplemortar
67

78
set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
@@ -11,21 +12,21 @@ Base.:*(ft::FusionTensor, x::Number) = x * ft
1112

1213
# tensor contraction is a block data_matrix product.
1314
function Base.:*(left::FusionTensor, right::FusionTensor)
14-
checkaxes_dual(domain_axes(left), codomain_axes(right))
15+
checkspaces_dual(domain_axes(left), codomain_axes(right))
1516
new_data_matrix = data_matrix(left) * data_matrix(right)
1617
return FusionTensor(new_data_matrix, codomain_axes(left), domain_axes(right))
1718
end
1819

1920
# tensor addition is a block data_matrix add.
2021
function Base.:+(left::FusionTensor, right::FusionTensor)
21-
checkaxes(axes(left), axes(right))
22+
checkspaces(axes(left), axes(right))
2223
return set_data_matrix(left, data_matrix(left) + data_matrix(right))
2324
end
2425

2526
Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft))
2627

2728
function Base.:-(left::FusionTensor, right::FusionTensor)
28-
checkaxes(axes(left), axes(right))
29+
checkspaces(axes(left), axes(right))
2930
return set_data_matrix(left, data_matrix(left) - data_matrix(right))
3031
end
3132

src/fusiontensor/fusiontensor.jl

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using GradedArrays:
99
SectorProduct,
1010
TrivialSector,
1111
dual,
12+
findfirstblock,
1213
flip,
1314
flip_dual,
1415
gradedrange,
@@ -27,23 +28,6 @@ using TypeParameterAccessors: type_parameters
2728

2829
# ======================================= Misc ===========================================
2930

30-
# TBD move to GradedArrays? rename findfirst_sector?
31-
function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange)
32-
return findfirst(==(s), sectors(flip_dual(g)))
33-
end
34-
35-
# TBD move to GradedArrays?
36-
function checkaxes(::Type{Bool}, axes1, axes2)
37-
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
38-
end
39-
40-
# TBD move to GradedArrays?
41-
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
42-
function checkaxes(ax1, ax2)
43-
return checkaxes(Bool, ax1, ax2) ||
44-
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
45-
end
46-
4731
function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
4832
t = (b1, b2)
4933
return Block(Block.(t))[to_block_indices.(t)...]
@@ -260,9 +244,9 @@ end
260244
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
261245
# find outer block corresponding to fusion trees
262246
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
263-
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
264-
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
265-
return Block(b1..., b2...)
247+
b1 = findfirstblock.(flip_dual.(codomain_axes(ft)), leaves(f1))
248+
b2 = findfirstblock.(flip_dual.(domain_axes(ft)), leaves(f2))
249+
return Block(Int.(b1)..., Int.(b2)...)
266250
end
267251

268252
# ============================== GradedArrays interface ==================================

src/fusiontensor/fusiontensoraxes.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ function GradedArrays.sector_type(::Type{FTA}) where {BT,FTA<:FusionTensorAxes{B
110110
return sector_type(type_parameters(type_parameters(BT, 3), 1))
111111
end
112112

113+
function GradedArrays.checkspaces(
114+
::Type{Bool}, left::FusionTensorAxes, right::FusionTensorAxes
115+
)
116+
return left == right
117+
end
118+
113119
# ============================== FusionTensor interface ==================================
114120

115121
codomain(fta::FusionTensorAxes) = fta[Block(1)]

src/fusiontensor/linear_algebra_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LinearAlgebra: LinearAlgebra, mul!, norm, tr
55
using BlockArrays: Block, blocks
66

77
using BlockSparseArrays: eachblockstoredindex
8-
using GradedArrays: quantum_dimension, sectors
8+
using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors
99

1010
# allow to contract with different eltype and let BlockSparseArray ensure compatibility
1111
# impose matching type and number of axes at compile time
@@ -27,9 +27,9 @@ function LinearAlgebra.mul!(
2727
end
2828

2929
# input validation
30-
checkaxes_dual(domain_axes(A), codomain_axes(B))
31-
checkaxes(codomain_axes(C), codomain_axes(A))
32-
checkaxes(domain_axes(C), domain_axes(B))
30+
checkspaces_dual(domain_axes(A), codomain_axes(B))
31+
checkspaces(codomain_axes(C), codomain_axes(A))
32+
checkspaces(domain_axes(C), domain_axes(B))
3333
mul!(data_matrix(C), data_matrix(A), data_matrix(B), α, β)
3434
return C
3535
end

test/setup.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ using FusionTensors:
66
codomain_axes,
77
data_matrix,
88
domain_axes,
9-
checkaxes,
10-
checkaxes_dual,
119
domain_axis,
1210
codomain_axis,
1311
ndims_codomain,
1412
ndims_domain
15-
using GradedArrays: dual, sectors, sector_multiplicities, space_isequal
13+
using GradedArrays:
14+
checkspaces, checkspaces_dual, dual, sectors, sector_multiplicities, space_isequal
1615

1716
function check_sanity(ft::FusionTensor)
1817
nca = ndims_domain(ft)
@@ -25,8 +24,8 @@ function check_sanity(ft::FusionTensor)
2524
@assert nda + nca == ndims(ft) "invalid ndims"
2625

2726
@assert length(axes(ft)) == ndims(ft) "ndims does not match axes"
28-
checkaxes(axes(ft)[begin:nda], codomain_axes(ft))
29-
checkaxes(axes(ft)[(nda + 1):end], domain_axes(ft))
27+
checkspaces(axes(ft)[begin:nda], codomain_axes(ft))
28+
checkspaces(axes(ft)[(nda + 1):end], domain_axes(ft))
3029

3130
m = data_matrix(ft)
3231
@assert ndims(m) == 2 "invalid data_matrix ndims"

test/test_basics.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ using FusionTensors:
99
data_matrix,
1010
domain_axes,
1111
FusionTensor,
12-
checkaxes,
13-
checkaxes_dual,
1412
codomain_axis,
1513
domain_axis,
1614
ndims_domain,
@@ -21,6 +19,8 @@ using GradedArrays:
2119
SectorProduct,
2220
TrivialSector,
2321
Z,
22+
checkspaces,
23+
checkspaces_dual,
2424
dual,
2525
flip,
2626
gradedrange,
@@ -54,11 +54,11 @@ include("setup.jl")
5454

5555
# getters
5656
@test data_matrix(ft1) == m
57-
@test checkaxes(axes(ft1), tuplemortar(((g1,), (g2,))))
57+
@test axes(ft1) == FusionTensorAxes((g1,), (g2,))
5858

5959
# misc
60-
@test checkaxes(codomain_axes(ft1), (g1,))
61-
@test checkaxes(domain_axes(ft1), (g2,))
60+
@test checkspaces(codomain_axes(ft1), (g1,))
61+
@test checkspaces(domain_axes(ft1), (g2,))
6262
@test ndims_codomain(ft1) == 1
6363
@test ndims_domain(ft1) == 1
6464
@test size(data_matrix(ft1)) == (6, 5)
@@ -86,42 +86,42 @@ include("setup.jl")
8686
@test ft2 !== ft1
8787
@test data_matrix(ft2) == data_matrix(ft1)
8888
@test data_matrix(ft2) !== data_matrix(ft1)
89-
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
90-
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
89+
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
90+
@test checkspaces(domain_axes(ft2), domain_axes(ft1))
9191

9292
ft2 = deepcopy(ft1)
9393
@test ft2 !== ft1
9494
@test data_matrix(ft2) == data_matrix(ft1)
9595
@test data_matrix(ft2) !== data_matrix(ft1)
96-
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
97-
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
96+
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
97+
@test checkspaces(domain_axes(ft2), domain_axes(ft1))
9898

9999
# similar
100100
ft2 = similar(ft1)
101101
@test isnothing(check_sanity(ft2))
102102
@test eltype(ft2) == Float64
103-
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
104-
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
103+
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
104+
@test checkspaces(domain_axes(ft2), domain_axes(ft1))
105105

106106
ft3 = similar(ft1, ComplexF64)
107107
@test isnothing(check_sanity(ft3))
108108
@test eltype(ft3) == ComplexF64
109-
@test checkaxes(codomain_axes(ft3), codomain_axes(ft1))
110-
@test checkaxes(domain_axes(ft3), domain_axes(ft1))
109+
@test checkspaces(codomain_axes(ft3), codomain_axes(ft1))
110+
@test checkspaces(domain_axes(ft3), domain_axes(ft1))
111111

112112
@test_throws AssertionError similar(ft1, Int)
113113

114114
ft5 = similar(ft1, ComplexF32, ((g1, g1), (g2,)))
115115
@test isnothing(check_sanity(ft5))
116116
@test eltype(ft5) == ComplexF64
117-
@test checkaxes(codomain_axes(ft5), (g1, g1))
118-
@test checkaxes(domain_axes(ft5), (g2,))
117+
@test checkspaces(codomain_axes(ft5), (g1, g1))
118+
@test checkspaces(domain_axes(ft5), (g2,))
119119

120120
ft5 = similar(ft1, ComplexF32, tuplemortar(((g1, g1), (g2,))))
121121
@test isnothing(check_sanity(ft5))
122122
@test eltype(ft5) == ComplexF64
123-
@test checkaxes(codomain_axes(ft5), (g1, g1))
124-
@test checkaxes(domain_axes(ft5), (g2,))
123+
@test checkspaces(codomain_axes(ft5), (g1, g1))
124+
@test checkspaces(domain_axes(ft5), (g2,))
125125
end
126126

127127
@testset "More than 2 axes" begin
@@ -135,8 +135,8 @@ end
135135
ft = FusionTensor(m2, (g1, g2), (g3, g4))
136136

137137
@test data_matrix(ft) == m2
138-
@test checkaxes(codomain_axes(ft), (g1, g2))
139-
@test checkaxes(domain_axes(ft), (g3, g4))
138+
@test checkspaces(codomain_axes(ft), (g1, g2))
139+
@test checkspaces(domain_axes(ft), (g3, g4))
140140

141141
@test axes(ft) == FusionTensorAxes(tuplemortar(((g1, g2), (g3, g4))))
142142
@test ndims_codomain(ft) == 2
@@ -269,9 +269,9 @@ end
269269
@test isnothing(check_sanity(ad))
270270

271271
ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
272-
@test_throws DimensionMismatch ft7 + ft3
273-
@test_throws DimensionMismatch ft7 - ft3
274-
@test_throws DimensionMismatch ft7 * ft3
272+
@test_throws ArgumentError ft7 + ft3
273+
@test_throws ArgumentError ft7 - ft3
274+
@test_throws ArgumentError ft7 * ft3
275275
end
276276

277277
@testset "specific constructors" begin

test/test_fusiontensoraxes.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test: @test, @testset
1+
using Test: @test, @test_throws, @testset
22

33
using TensorProducts:
44
using BlockArrays: Block, blockedrange, blocklength, blocklengths, blocks
@@ -15,7 +15,16 @@ using FusionTensors:
1515
promote_sector_type,
1616
promote_sectors
1717
using GradedArrays:
18-
×, U1, SectorProduct, TrivialSector, SU2, dual, gradedrange, sector_type, space_isequal
18+
×,
19+
U1,
20+
SectorProduct,
21+
TrivialSector,
22+
SU2,
23+
checkspaces,
24+
dual,
25+
gradedrange,
26+
sector_type,
27+
space_isequal
1928

2029
@testset "misc FusionTensors.jl" begin
2130
g1 = gradedrange([U1(0) => 1])
@@ -83,6 +92,10 @@ end
8392
@test fta != FusionTensorAxes(tuplemortar(((g2, g2, g2b), (g2b,))))
8493

8594
@test fta == FusionTensorAxes((g2, g2), (g2b, g2b))
95+
@test checkspaces(fta, fta)
96+
@test_throws ArgumentError checkspaces(
97+
fta, FusionTensorAxes(tuplemortar(((g2, g2), (g2b, g2))))
98+
)
8699
end
87100

88101
@testset "Empty FusionTensorAxes" begin

test/test_permutedims.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ using Test: @test, @testset, @test_broken, @test_throws
22

33
using FusionTensors:
44
FusionTensor,
5+
FusionTensorAxes,
56
data_matrix,
6-
checkaxes,
77
codomain_axis,
88
domain_axis,
99
naive_permutedims,
@@ -40,11 +40,11 @@ include("setup.jl")
4040
ft3 = permutedims(ft1, (4,), (1, 2, 3))
4141
@test ft3 !== ft1
4242
@test ft3 isa FusionTensor{elt,4}
43-
@test checkaxes(axes(ft3), tuplemortar(((dual(g4),), (g1, g2, dual(g3)))))
43+
@test axes(ft3) == FusionTensorAxes((dual(g4),), (g1, g2, dual(g3)))
4444
@test isnothing(check_sanity(ft3))
4545

4646
ft4 = permutedims(ft3, (2, 3), (4, 1))
47-
@test checkaxes(axes(ft1), axes(ft4))
47+
@test axes(ft1) == axes(ft4)
4848
@test space_isequal(codomain_axis(ft1), codomain_axis(ft4))
4949
@test space_isequal(domain_axis(ft1), domain_axis(ft4))
5050
@test ft4 ft1

0 commit comments

Comments
 (0)