Skip to content

Commit 14c95ca

Browse files
committed
generalize gradedisequal
1 parent 113a1fe commit 14c95ca

File tree

8 files changed

+77
-27
lines changed

8 files changed

+77
-27
lines changed

NDTensors/src/lib/GradedAxes/src/GradedAxes.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module GradedAxes
22
include("blockedunitrange.jl")
33
include("gradedunitrange.jl")
44
include("dual.jl")
5-
include("gradedunitrangedual.jl")
65
include("unitrangedual.jl")
6+
include("gradedunitrangedual.jl")
7+
include("onetoone.jl")
78
include("fusion.jl")
89
end

NDTensors/src/lib/GradedAxes/src/fusion.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,5 @@
11
using BlockArrays: AbstractBlockedUnitRange
22

3-
# Represents the range `1:1` or `Base.OneTo(1)`.
4-
struct OneToOne{T} <: AbstractUnitRange{T} end
5-
OneToOne() = OneToOne{Bool}()
6-
Base.first(a::OneToOne) = one(eltype(a))
7-
Base.last(a::OneToOne) = one(eltype(a))
8-
9-
gradedisequal(::AbstractUnitRange, ::OneToOne) = false
10-
gradedisequal(::OneToOne, ::AbstractUnitRange) = false
11-
gradedisequal(::OneToOne, ::OneToOne) = true
12-
133
# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
144
# https://en.wikipedia.org/wiki/Tensor_product
155
# https://github.com/KeitaNakamura/Tensorial.jl

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T}
3737
return unlabel_blocks(a)
3838
end
3939

40-
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
41-
function gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
42-
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
43-
end
44-
4540
# This is only needed in certain Julia versions below 1.10
4641
# (for example Julia 1.6).
4742
# TODO: Delete this once we drop Julia 1.6 support.

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,6 @@ end
117117

118118
blocklabels(a::GradedUnitRangeDual) = dual.(blocklabels(nondual(a)))
119119

120-
gradedisequal(::GradedUnitRangeDual, ::AbstractGradedUnitRange) = false
121-
gradedisequal(::AbstractGradedUnitRange, ::GradedUnitRangeDual) = false
122-
function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
123-
return gradedisequal(nondual(a1), nondual(a2))
124-
end
125120
function BlockArrays.combine_blockaxes(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
126121
return dual(combine_blockaxes(dual(a1), dual(a2)))
127122
end
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using BlockArrays: AbstractBlockedUnitRange
2+
using ..LabelledNumbers: islabelled
3+
4+
# Represents the range `1:1` or `Base.OneTo(1)`.
5+
struct OneToOne{T} <: AbstractUnitRange{T} end
6+
OneToOne() = OneToOne{Bool}()
7+
Base.first(a::OneToOne) = one(eltype(a))
8+
Base.last(a::OneToOne) = one(eltype(a))
9+
10+
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
11+
gradedisequal(::AbstractBlockedUnitRange, ::AbstractUnitRange) = false
12+
gradedisequal(::AbstractUnitRange, ::AbstractBlockedUnitRange) = false
13+
gradedisequal(::AbstractBlockedUnitRange, ::OneToOne) = false
14+
gradedisequal(::OneToOne, ::AbstractBlockedUnitRange) = false
15+
function gradedisequal(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
16+
return blockisequal(a1, a2)
17+
end
18+
function gradedisequal(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
19+
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
20+
end
21+
gradedisequal(::GradedUnitRangeDual, ::GradedUnitRange) = false
22+
gradedisequal(::GradedUnitRange, ::GradedUnitRangeDual) = false
23+
function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual)
24+
return gradedisequal(nondual(a1), nondual(a2))
25+
end
26+
27+
gradedisequal(::OneToOne, ::OneToOne) = true
28+
29+
function gradedisequal(::OneToOne, g::AbstractUnitRange)
30+
return !islabelled(eltype(g)) && (first(g) == last(g) == 1)
31+
end
32+
gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g)
33+
34+
gradedisequal(::UnitRangeDual, ::AbstractUnitRange) = false
35+
gradedisequal(::AbstractUnitRange, ::UnitRangeDual) = false
36+
gradedisequal(::OneToOne, ::UnitRangeDual) = false
37+
gradedisequal(::UnitRangeDual, ::OneToOne) = false
38+
function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual)
39+
return gradedisequal(nondual(a1), nondual(a2))
40+
end
41+
42+
gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2

NDTensors/src/lib/GradedAxes/src/unitrangedual.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,6 @@ function BlockArrays.findblock(a::UnitRangeDual, index::Integer)
107107
return findblock(nondual(a), index)
108108
end
109109

110-
gradedisequal(::UnitRangeDual, ::AbstractGradedUnitRange) = false
111-
gradedisequal(::AbstractGradedUnitRange, ::UnitRangeDual) = false
112-
function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual)
113-
return gradedisequal(nondual(a1), nondual(a2))
114-
end
115110
function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual)
116111
return dual(combine_blockaxes(dual(a1), dual(a2)))
117112
end

NDTensors/src/lib/GradedAxes/test/test_basics.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,35 @@ using BlockArrays:
99
blocklength,
1010
blocklengths,
1111
blocks
12-
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, blocklabels, gradedrange
12+
using NDTensors.GradedAxes:
13+
GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange
1314
using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, unlabel
1415
using Test: @test, @test_broken, @testset
16+
17+
@testset "OneToOne" begin
18+
a0 = OneToOne()
19+
@test a0 isa OneToOne{Bool}
20+
@test eltype(a0) == Bool
21+
@test length(a0) == 1
22+
@test gradedisequal(a0, a0)
23+
24+
@test gradedisequal(a0, 1:1)
25+
@test gradedisequal(1:1, a0)
26+
@test !gradedisequal(a0, 1:2)
27+
@test !gradedisequal(1:2, a0)
28+
end
1529
@testset "GradedAxes basics" begin
30+
a0 = OneToOne()
1631
for a in (
1732
blockedrange([labelled(2, "x"), labelled(3, "y")]),
1833
gradedrange([labelled(2, "x"), labelled(3, "y")]),
1934
gradedrange(["x" => 2, "y" => 3]),
2035
)
2136
@test a isa GradedOneTo
37+
@test gradedisequal(a, a)
38+
@test !gradedisequal(a0, a)
39+
@test !gradedisequal(a, a0)
40+
@test !gradedisequal(a, 1:5)
2241
for x in iterate(a)
2342
@test x == 1
2443
@test label(x) == "x"

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
4343
@test isdual(ad)
4444
@test !isdual(a)
4545
@test length(ad) == 1
46+
@test !gradedisequal(a, ad)
47+
@test !gradedisequal(ad, a)
48+
@test gradedisequal(ad, ad)
4649
end
4750
@testset "dual(UnitRange)" begin
4851
a = 1:3
@@ -55,6 +58,16 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
5558
@test isdual(ad)
5659
@test !isdual(a)
5760
@test length(ad) == 3
61+
62+
@test !gradedisequal(ad, a)
63+
@test !gradedisequal(a, ad)
64+
@test gradedisequal(ad, ad)
65+
66+
a0 = OneToOne()
67+
@test !gradedisequal(ad, a0)
68+
@test !gradedisequal(a0, ad)
69+
@test !gradedisequal(dual(a0), ad)
70+
@test !gradedisequal(ad, dual(a0))
5871
end
5972
@testset "dual(BlockedOneTo)" begin
6073
a = blockedrange([2, 3])

0 commit comments

Comments
 (0)