Skip to content

Commit c57bc02

Browse files
committed
fix GradedUnitRangeDual tests
1 parent db71962 commit c57bc02

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

NDTensors/src/lib/GradedAxes/src/dual.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ isdual(::AbstractUnitRange) = false
55

66
using NDTensors.LabelledNumbers:
77
LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel
8+
9+
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
810
label_dual(x) = label_dual(LabelledStyle(x), x)
911
label_dual(::NotLabelled, x) = x
1012
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ end
104104
Base.unitrange(a::GradedUnitRangeDual) = a
105105

106106
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel
107-
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
108-
109107
using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock
110108
BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a))
111109
BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a)))

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using BlockArrays:
1212
blocks,
1313
findblock
1414
using NDTensors.GradedAxes:
15+
AbstractGradedUnitRange,
1516
GradedAxes,
1617
GradedUnitRangeDual,
1718
OneToOne,
@@ -60,6 +61,7 @@ end
6061
[gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]]
6162
ad = dual(a)
6263
@test ad isa GradedUnitRangeDual
64+
@test ad isa AbstractGradedUnitRange
6365
@test eltype(ad) == LabelledInteger{Int,U1}
6466
@test blocklengths(ad) isa Vector
6567
@test eltype(blocklengths(ad)) == eltype(blocklengths(a))
@@ -78,6 +80,8 @@ end
7880
@test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))]
7981
@test blocklength(ad) == 2
8082
@test blocklengths(ad) == [2, 3]
83+
@test blocklabels(ad) == [U1(0), U1(-1)]
84+
@test label.(blocklengths(ad)) == [U1(0), U1(-1)]
8185
@test findblock(ad, 4) == Block(2)
8286
@test only(blockaxes(ad)) == Block(1):Block(2)
8387
@test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))]

0 commit comments

Comments
 (0)