Skip to content

Commit 3f6bd2e

Browse files
committed
passing tests
1 parent cc6d7ea commit 3f6bd2e

File tree

5 files changed

+27
-58
lines changed

5 files changed

+27
-58
lines changed

NDTensors/src/lib/GradedAxes/src/fusion.jl

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,28 @@ function tensor_product(
1212
return foldl(tensor_product, (a1, a2, a3, a_rest...))
1313
end
1414

15-
function tensor_product(::AbstractUnitRange, ::AbstractUnitRange)
16-
return error("Not implemented yet.")
15+
flip_dual(r::AbstractUnitRange) = r
16+
flip_dual(r::GradedUnitRangeDual) = flip(r)
17+
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
18+
return tensor_product(flip_dual(a1), flip_dual(a2))
1719
end
1820

1921
function tensor_product(a1::Base.OneTo, a2::Base.OneTo)
2022
return Base.OneTo(length(a1) * length(a2))
2123
end
2224

23-
function tensor_product(::OneToOne, a2::AbstractBlockedUnitRange)
25+
function tensor_product(::OneToOne, a2::AbstractUnitRange)
2426
return a2
2527
end
2628

27-
function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne)
29+
function tensor_product(a1::AbstractUnitRange, ::OneToOne)
2830
return a1
2931
end
3032

3133
function tensor_product(::OneToOne, ::OneToOne)
3234
return OneToOne()
3335
end
3436

35-
# Handle dual. Always return a non-dual GradedUnitRange.
36-
function tensor_product(a1::AbstractBlockedUnitRange, a2::GradedUnitRangeDual)
37-
return tensor_product(a1, flip(a2))
38-
end
39-
40-
function tensor_product(a1::GradedUnitRangeDual, a2::AbstractBlockedUnitRange)
41-
return tensor_product(flip(a1), a2)
42-
end
43-
44-
function tensor_product(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
45-
return tensor_product(flip(a1), flip(a2))
46-
end
47-
4837
function fuse_labels(x, y)
4938
return error(
5039
"`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`."
@@ -98,7 +87,8 @@ end
9887
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
9988
# Get the permutation for sorting, then group by common elements.
10089
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
101-
function blockmergesortperm(a::AbstractBlockedUnitRange)
90+
blockmergesort(g::AbstractUnitRange) = g
91+
function blockmergesortperm(a::AbstractUnitRange)
10292
return Block.(groupsortperm(blocklabels(a)))
10393
end
10494

@@ -120,7 +110,6 @@ function blockmergesort(g::AbstractGradedUnitRange)
120110
end
121111

122112
blockmergesort(g::GradedUnitRangeDual) = dual(blockmergesort(flip(g)))
123-
blockmergesort(g::OneToOne) = g
124113

125114
# fusion_product produces a sorted, non-dual GradedUnitRange
126115
function fusion_product(g1, g2)

NDTensors/src/lib/GradedAxes/src/onetoone.jl

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,4 @@ struct OneToOne{T} <: AbstractUnitRange{T} end
66
OneToOne() = OneToOne{Bool}()
77
Base.first(a::OneToOne) = one(eltype(a))
88
Base.last(a::OneToOne) = one(eltype(a))
9-
10-
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
11-
gradedisequal(::AbstractBlockedUnitRange, ::AbstractUnitRange) = false
12-
gradedisequal(::AbstractUnitRange, ::AbstractBlockedUnitRange) = false
13-
gradedisequal(::AbstractBlockedUnitRange, ::OneToOne) = false
14-
gradedisequal(::OneToOne, ::AbstractBlockedUnitRange) = false
15-
function gradedisequal(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
16-
return blockisequal(a1, a2)
17-
end
18-
function gradedisequal(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
19-
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
20-
end
21-
gradedisequal(::GradedUnitRangeDual, ::GradedUnitRange) = false
22-
gradedisequal(::GradedUnitRange, ::GradedUnitRangeDual) = false
23-
function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
24-
return gradedisequal(nondual(a1), nondual(a2))
25-
end
26-
27-
gradedisequal(::OneToOne, ::OneToOne) = true
28-
29-
function gradedisequal(::OneToOne, g::AbstractUnitRange)
30-
return !islabelled(eltype(g)) && (first(g) == last(g) == 1)
31-
end
32-
gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g)
33-
gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2
9+
BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool}

NDTensors/src/lib/GradedAxes/test/test_basics.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ using BlockArrays:
99
blocklength,
1010
blocklengths,
1111
blocks
12-
using NDTensors.GradedAxes:
13-
GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange
12+
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange
1413
using NDTensors.LabelledNumbers:
1514
LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel
1615
using Test: @test, @test_broken, @testset
@@ -20,13 +19,14 @@ using Test: @test, @test_broken, @testset
2019
@test a0 isa OneToOne{Bool}
2120
@test eltype(a0) == Bool
2221
@test length(a0) == 1
23-
@test gradedisequal(a0, a0)
22+
@test labelled_isequal(a0, a0)
2423

25-
@test gradedisequal(a0, 1:1)
26-
@test gradedisequal(1:1, a0)
27-
@test !gradedisequal(a0, 1:2)
28-
@test !gradedisequal(1:2, a0)
24+
@test labelled_isequal(a0, 1:1)
25+
@test labelled_isequal(1:1, a0)
26+
@test !labelled_isequal(a0, 1:2)
27+
@test !labelled_isequal(1:2, a0)
2928
end
29+
3030
@testset "GradedAxes basics" begin
3131
a0 = OneToOne()
3232
for a in (
@@ -35,10 +35,10 @@ end
3535
gradedrange(["x" => 2, "y" => 3]),
3636
)
3737
@test a isa GradedOneTo
38-
@test gradedisequal(a, a)
39-
@test !gradedisequal(a0, a)
40-
@test !gradedisequal(a, a0)
41-
@test !gradedisequal(a, 1:5)
38+
@test labelled_isequal(a, a)
39+
@test !labelled_isequal(a0, a)
40+
@test !labelled_isequal(a, a0)
41+
@test !labelled_isequal(a, 1:5)
4242
for x in iterate(a)
4343
@test x == 1
4444
@test label(x) == "x"

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using BlockArrays:
55
blockaxes,
66
blockedrange,
77
blockfirsts,
8+
blockisequal,
89
blocklasts,
910
blocklength,
1011
blocklengths,
@@ -23,7 +24,7 @@ using NDTensors.GradedAxes:
2324
gradedrange,
2425
isdual,
2526
nondual
26-
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled
27+
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
2728
using Test: @test, @testset
2829
struct U1
2930
n::Int
@@ -36,6 +37,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
3637
@test !isdual(a0)
3738
@test dual(a0) isa OneToOne
3839
@test space_isequal(a0, a0)
40+
@test labelled_isequal(a0, a0)
3941
@test space_isequal(a0, dual(a0))
4042

4143
a = 1:3
@@ -50,7 +52,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
5052
@test !isdual(a)
5153
@test !isdual(ad)
5254
@test ad isa BlockedOneTo
53-
@test space_isequal(ad, a)
55+
@test blockisequal(ad, a)
5456
end
5557

5658
@testset "GradedUnitRangeDual" begin

NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ using NDTensors.GradedAxes:
1111
fusion_product,
1212
flip,
1313
gradedrange,
14-
labelled_isequal,
1514
space_isequal,
1615
isdual,
1716
tensor_product
1817

18+
using NDTensors.LabelledNumbers: labelled_isequal
19+
1920
struct U1
2021
n::Int
2122
end
@@ -27,6 +28,7 @@ GradedAxes.fuse_labels(x::U1, y::U1) = U1(x.n + y.n)
2728
GradedAxes.fuse_labels(x::String, y::String) = x * y
2829

2930
g0 = OneToOne()
31+
@test labelled_isequal(g0, g0)
3032
@test labelled_isequal(tensor_product(g0, g0), g0)
3133

3234
a = gradedrange(["x" => 2, "y" => 3])

0 commit comments

Comments
 (0)