Skip to content

Commit 22c5e51

Browse files
committed
replace matching_axes with checkaxes
1 parent 9247074 commit 22c5e51

File tree

6 files changed

+42
-40
lines changed

6 files changed

+42
-40
lines changed

src/fusiontensor/base_interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Base.:*(ft::FusionTensor, x::Number) = set_data_matrix(ft, x * data_matrix(ft))
1111

1212
# tensor contraction is a block data_matrix product.
1313
function Base.:*(left::FusionTensor, right::FusionTensor)
14-
@assert matching_dual(domain_axes(left), codomain_axes(right))
14+
checkaxes_dual(domain_axes(left), codomain_axes(right))
1515
new_data_matrix = data_matrix(left) * data_matrix(right)
1616
return fusiontensor(new_data_matrix, codomain_axes(left), domain_axes(right))
1717
end
@@ -20,14 +20,14 @@ Base.:+(ft::FusionTensor) = ft
2020

2121
# tensor addition is a block data_matrix add.
2222
function Base.:+(left::FusionTensor, right::FusionTensor)
23-
@assert matching_axes(axes(left), axes(right))
23+
checkaxes(axes(left), axes(right))
2424
return set_data_matrix(left, data_matrix(left) + data_matrix(right))
2525
end
2626

2727
Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft))
2828

2929
function Base.:-(left::FusionTensor, right::FusionTensor)
30-
@assert matching_axes(axes(left), axes(right))
30+
checkaxes(axes(left), axes(right))
3131
return set_data_matrix(left, data_matrix(left) - data_matrix(right))
3232
end
3333

src/fusiontensor/fusiontensor.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,11 @@ function initialize_allowed_sectors!(mat::AbstractMatrix)
182182
end
183183
end
184184

185-
matching_dual(axes1::Tuple, axes2::Tuple) = matching_axes(axes1, dual.(axes2))
186-
matching_axes(axes1::Tuple, axes2::Tuple) = false
187-
function matching_axes(axes1::T, axes2::T) where {T<:Tuple}
188-
return all(space_isequal.(axes1, axes2))
185+
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
186+
function checkaxes(ax1, ax2)
187+
return checkaxes(Bool, ax1, ax2) ||
188+
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
189+
end
190+
function checkaxes(::Type{Bool}, axes1, axes2)
191+
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
189192
end

src/fusiontensor/linear_algebra_interface.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,9 @@ function LinearAlgebra.mul!(
2828
end
2929

3030
# input validation
31-
if !matching_dual(domain_axes(A), codomain_axes(B))
32-
throw(codomainError("Incompatible tensor axes for A and B"))
33-
end
34-
if !matching_axes(codomain_axes(C), codomain_axes(A))
35-
throw(codomainError("Incompatible tensor axes for C and A"))
36-
end
37-
if !matching_axes(domain_axes(C), domain_axes(B))
38-
throw(codomainError("Incompatible tensor axes for C and B"))
39-
end
31+
checkaxes_dual(domain_axes(A), codomain_axes(B))
32+
checkaxes(codomain_axes(C), codomain_axes(A))
33+
checkaxes(domain_axes(C), domain_axes(B))
4034
mul!(data_matrix(C), data_matrix(A), data_matrix(B), α, β)
4135
return C
4236
end

test/basics/setup.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ using FusionTensors:
66
codomain_axes,
77
data_matrix,
88
domain_axes,
9-
matching_axes,
10-
matching_dual,
9+
checkaxes,
10+
checkaxes_dual,
1111
matrix_column_axis,
1212
matrix_row_axis,
1313
ndims_codomain,
@@ -37,8 +37,8 @@ function check_sanity(ft::FusionTensor)
3737
@assert nda + nca == ndims(ft) "invalid ndims"
3838

3939
@assert length(axes(ft)) == ndims(ft) "ndims does not match axes"
40-
@assert matching_axes(axes(ft)[begin:nda], codomain_axes(ft)) "axes do not match codomain_axes"
41-
@assert matching_axes(axes(ft)[(nda + 1):end], domain_axes(ft)) "axes do not match domain_axes"
40+
checkaxes(axes(ft)[begin:nda], codomain_axes(ft))
41+
checkaxes(axes(ft)[(nda + 1):end], domain_axes(ft))
4242

4343
m = data_matrix(ft)
4444
@assert ndims(m) == 2 "invalid data_matrix ndims"

test/basics/test_basics.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ using FusionTensors:
77
data_matrix,
88
domain_axes,
99
fusiontensor,
10-
matching_axes,
11-
matching_dual,
10+
checkaxes,
11+
checkaxes_dual,
1212
matrix_column_axis,
1313
matrix_row_axis,
1414
matrix_size,
@@ -34,11 +34,11 @@ include("setup.jl")
3434

3535
# getters
3636
@test data_matrix(ft1) == m
37-
@test matching_axes(codomain_axes(ft1), (g1,))
38-
@test matching_axes(domain_axes(ft1), (g2,))
37+
@test checkaxes(codomain_axes(ft1), (g1,))
38+
@test checkaxes(domain_axes(ft1), (g2,))
3939

4040
# misc
41-
@test matching_axes(axes(ft1), (g1, g2))
41+
@test checkaxes(axes(ft1), (g1, g2))
4242
@test ndims_codomain(ft1) == 1
4343
@test ndims_domain(ft1) == 1
4444
@test matrix_size(ft1) == (6, 5)
@@ -60,36 +60,36 @@ include("setup.jl")
6060
@test ft2 !== ft1
6161
@test data_matrix(ft2) == data_matrix(ft1)
6262
@test data_matrix(ft2) !== data_matrix(ft1)
63-
@test matching_axes(codomain_axes(ft2), codomain_axes(ft1))
64-
@test matching_axes(domain_axes(ft2), domain_axes(ft1))
63+
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
64+
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
6565

6666
ft2 = deepcopy(ft1)
6767
@test ft2 !== ft1
6868
@test data_matrix(ft2) == data_matrix(ft1)
6969
@test data_matrix(ft2) !== data_matrix(ft1)
70-
@test matching_axes(codomain_axes(ft2), codomain_axes(ft1))
71-
@test matching_axes(domain_axes(ft2), domain_axes(ft1))
70+
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
71+
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
7272

7373
# similar
7474
ft2 = similar(ft1)
7575
@test isnothing(check_sanity(ft2))
7676
@test eltype(ft2) == Float64
77-
@test matching_axes(codomain_axes(ft2), codomain_axes(ft1))
78-
@test matching_axes(domain_axes(ft2), domain_axes(ft1))
77+
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
78+
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
7979

8080
ft3 = similar(ft1, ComplexF64)
8181
@test isnothing(check_sanity(ft3))
8282
@test eltype(ft3) == ComplexF64
83-
@test matching_axes(codomain_axes(ft3), codomain_axes(ft1))
84-
@test matching_axes(domain_axes(ft3), domain_axes(ft1))
83+
@test checkaxes(codomain_axes(ft3), codomain_axes(ft1))
84+
@test checkaxes(domain_axes(ft3), domain_axes(ft1))
8585

8686
@test_throws AssertionError similar(ft1, Int)
8787

8888
ft5 = similar(ft1, ComplexF32, ((g1, g1), (g2,)))
8989
@test isnothing(check_sanity(ft5))
9090
@test eltype(ft5) == ComplexF64
91-
@test matching_axes(codomain_axes(ft5), (g1, g1))
92-
@test matching_axes(domain_axes(ft5), (g2,))
91+
@test checkaxes(codomain_axes(ft5), (g1, g1))
92+
@test checkaxes(domain_axes(ft5), (g2,))
9393
end
9494

9595
@testset "More than 2 axes" begin
@@ -103,8 +103,8 @@ end
103103
ft = fusiontensor(m2, (g1, g2), (g3, g4))
104104

105105
@test data_matrix(ft) == m2
106-
@test matching_axes(codomain_axes(ft), (g1, g2))
107-
@test matching_axes(domain_axes(ft), (g3, g4))
106+
@test checkaxes(codomain_axes(ft), (g1, g2))
107+
@test checkaxes(domain_axes(ft), (g3, g4))
108108

109109
@test axes(ft) == (g1, g2, g3, g4)
110110
@test ndims_codomain(ft) == 2
@@ -232,6 +232,11 @@ end
232232
@test space_isequal(dual(g3), codomain_axes(ad)[1])
233233
@test space_isequal(dual(g4), codomain_axes(ad)[2])
234234
@test isnothing(check_sanity(ad))
235+
236+
ft7 = FusionTensor(Float64, (g1,), (g2, g3, g4))
237+
@test_throws DimensionMismatch ft7 + ft3
238+
@test_throws DimensionMismatch ft7 - ft3
239+
@test_throws DimensionMismatch ft7 * ft3
235240
end
236241

237242
@testset "mising SectorProduct" begin

test/basics/test_permutedims.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Test: @test, @testset, @test_broken
44
using FusionTensors:
55
FusionTensor,
66
data_matrix,
7-
matching_axes,
7+
checkaxes,
88
matrix_column_axis,
99
matrix_row_axis,
1010
naive_permutedims,
@@ -41,14 +41,14 @@ include("setup.jl")
4141
ft3 = permutedims(ft1, (4,), (1, 2, 3))
4242
@test ft3 !== ft1
4343
@test ft3 isa FusionTensor{elt,4}
44-
@test matching_axes(axes(ft3), (dual(g4), g1, g2, dual(g3)))
44+
@test checkaxes(axes(ft3), (dual(g4), g1, g2, dual(g3)))
4545
@test ndims_domain(ft3) == 3
4646
@test ndims_codomain(ft3) == 1
4747
@test ndims(ft3) == 4
4848
@test isnothing(check_sanity(ft3))
4949

5050
ft4 = permutedims(ft3, (2, 3), (4, 1))
51-
@test matching_axes(axes(ft1), axes(ft4))
51+
@test checkaxes(axes(ft1), axes(ft4))
5252
@test space_isequal(matrix_column_axis(ft1), matrix_column_axis(ft4))
5353
@test space_isequal(matrix_row_axis(ft1), matrix_row_axis(ft4))
5454
@test ft4 ft1

0 commit comments

Comments
 (0)