Skip to content

Commit 3c49d34

Browse files
committed
fix tests and debug
1 parent e53a3e0 commit 3c49d34

File tree

5 files changed

+14
-14
lines changed

5 files changed

+14
-14
lines changed

src/fusiontensor/base_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ end
1919

2020
# tensor addition is a block data_matrix add.
2121
function Base.:+(left::FusionTensor, right::FusionTensor)
22-
checkspaces(axes(left), axes(right))
22+
axes(left) == axes(right) || throw(ArgumentError("Axes do not match"))
2323
return set_data_matrix(left, data_matrix(left) + data_matrix(right))
2424
end
2525

2626
Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft))
2727

2828
function Base.:-(left::FusionTensor, right::FusionTensor)
29-
checkspaces(axes(left), axes(right))
29+
axes(left) == axes(right) || throw(ArgumentError("Axes do not match"))
3030
return set_data_matrix(left, data_matrix(left) - data_matrix(right))
3131
end
3232

src/fusiontensor/fusiontensor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ end
244244
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
245245
# find outer block corresponding to fusion trees
246246
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
247-
b1 = findfirstblock.(codomain_axes(ft), leaves(f1))
248-
b2 = findfirstblock.(domain_axes(ft), leaves(f2))
247+
b1 = findfirstblock.(flip_dual.(codomain_axes(ft)), leaves(f1))
248+
b2 = findfirstblock.(flip_dual.(domain_axes(ft)), leaves(f2))
249249
return Block(Int.(b1)..., Int.(b2)...)
250250
end
251251

src/fusiontensor/linear_algebra_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LinearAlgebra: LinearAlgebra, mul!, norm, tr
55
using BlockArrays: Block, blocks
66

77
using BlockSparseArrays: eachblockstoredindex
8-
using GradedArrays: quantum_dimension, sectors
8+
using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors
99

1010
# allow to contract with different eltype and let BlockSparseArray ensure compatibility
1111
# impose matching type and number of axes at compile time
@@ -27,9 +27,9 @@ function LinearAlgebra.mul!(
2727
end
2828

2929
# input validation
30-
checkaxes_dual(domain_axes(A), codomain_axes(B))
31-
checkaxes(codomain_axes(C), codomain_axes(A))
32-
checkaxes(domain_axes(C), domain_axes(B))
30+
checkspaces_dual(domain_axes(A), codomain_axes(B))
31+
checkspaces(codomain_axes(C), codomain_axes(A))
32+
checkspaces(domain_axes(C), domain_axes(B))
3333
mul!(data_matrix(C), data_matrix(A), data_matrix(B), α, β)
3434
return C
3535
end

test/test_basics.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ include("setup.jl")
5454

5555
# getters
5656
@test data_matrix(ft1) == m
57-
@test checkspaces(axes(ft1), tuplemortar(((g1,), (g2,))))
57+
@test axes(ft1) == FusionTensorAxes((g1,), (g2,))
5858

5959
# misc
6060
@test checkspaces(codomain_axes(ft1), (g1,))
@@ -269,8 +269,8 @@ end
269269
@test isnothing(check_sanity(ad))
270270

271271
ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
272-
@test_throws DimensionMismatch ft7 + ft3
273-
@test_throws DimensionMismatch ft7 - ft3
272+
@test_throws ArgumentError ft7 + ft3
273+
@test_throws ArgumentError ft7 - ft3
274274
@test_throws ArgumentError ft7 * ft3
275275
end
276276

test/test_permutedims.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ using Test: @test, @testset, @test_broken, @test_throws
22

33
using FusionTensors:
44
FusionTensor,
5+
FusionTensorAxes,
56
data_matrix,
6-
checkaxes,
77
codomain_axis,
88
domain_axis,
99
naive_permutedims,
@@ -40,11 +40,11 @@ include("setup.jl")
4040
ft3 = permutedims(ft1, (4,), (1, 2, 3))
4141
@test ft3 !== ft1
4242
@test ft3 isa FusionTensor{elt,4}
43-
@test checkaxes(axes(ft3), tuplemortar(((dual(g4),), (g1, g2, dual(g3)))))
43+
@test axes(ft3) == FusionTensorAxes((dual(g4),), (g1, g2, dual(g3)))
4444
@test isnothing(check_sanity(ft3))
4545

4646
ft4 = permutedims(ft3, (2, 3), (4, 1))
47-
@test checkaxes(axes(ft1), axes(ft4))
47+
@test axes(ft1) == axes(ft4)
4848
@test space_isequal(codomain_axis(ft1), codomain_axis(ft4))
4949
@test space_isequal(domain_axis(ft1), domain_axis(ft4))
5050
@test ft4 ft1

0 commit comments

Comments
 (0)