Skip to content

Commit 5d047a8

Browse files
committed
define labelledunitrangedual
1 parent 89adeb2 commit 5d047a8

File tree

5 files changed

+49
-6
lines changed

5 files changed

+49
-6
lines changed

NDTensors/src/lib/GradedAxes/src/GradedAxes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GradedAxes
22
include("blockedunitrange.jl")
33
include("gradedunitrange.jl")
44
include("dual.jl")
5+
include("labelledunitrangedual.jl")
56
include("gradedunitrangedual.jl")
67
include("onetoone.jl")
78
include("fusion.jl")

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ function blockedunitrange_getindices(
250250
# if they exist. This makes it so that
251251
# `only(axes(a[indices])) isa `GradedUnitRange`
252252
# if `a isa `GradedUnitRange`, for example.
253-
return mortar(blocks, length.(blocks))
253+
return mortar(blocks, length.(blocks)) # LOOSE DUAL
254254
end
255255

256256
# The block labels of the corresponding slice.

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
3131
end
3232

3333
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
34-
return label_dual(getindex(nondual(a), indices))
34+
return dual(getindex(nondual(a), indices))
3535
end
3636

3737
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange)
38-
return label_dual(getindex(nondual(a), indices))
38+
return dual(getindex(nondual(a), indices))
39+
end
40+
41+
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockIndexRange)
42+
return dual(nondual(a)[indices])
3943
end
4044

4145
# fix ambiguity
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# LabelledUnitRangeDual is obtained by slicing a GradedUnitRangeDual with a block
2+
3+
using ..LabelledNumbers: LabelledNumbers, label, labelled, unlabel
4+
5+
struct LabelledUnitRangeDual{T,NondualUnitRange<:AbstractUnitRange{T}} <:
6+
AbstractUnitRange{T}
7+
nondual_unitrange::NondualUnitRange
8+
end
9+
10+
dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
11+
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
12+
dual(a::LabelledUnitRangeDual) = nondual(a)
13+
flip(a::LabelledUnitRangeDual) = dual(flip(nondual(a)))
14+
isdual(::LabelledUnitRangeDual) = true
15+
16+
LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a)))
17+
LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a))
18+
19+
for f in [:first, :getindex, :last, :length, :step]
20+
@eval Base.$f(a::LabelledUnitRangeDual, args...) =
21+
labelled($f(unlabel(a), args...), label(a))
22+
end
23+
24+
# fix ambiguities
25+
Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i])
26+
27+
function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual)
28+
println(io, typeof(a))
29+
return print(io, label(a), " => ", unlabel(a))
30+
end
31+
32+
function Base.show(io::IO, a::LabelledUnitRangeDual)
33+
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
34+
end
35+
36+
function Base.AbstractUnitRange{T}(a::LabelledUnitRangeDual) where {T}
37+
return AbstractUnitRange{T}(nondual(a))
38+
end

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ end
124124
@test blockmergesortperm(a) == [Block(1), Block(2)]
125125
@test blockmergesortperm(ad) == [Block(1), Block(2)]
126126

127-
@test_broken isdual(ad[Block(1)])
128-
@test_broken isdual(ad[Block(1)[1:1]])
127+
@test isdual(ad[Block(1)])
128+
@test isdual(ad[Block(1)[1:1]])
129129
I = mortar([Block(2)[1:1]])
130130
g = ad[I]
131131
@test length(g) == 1
132132
@test label(first(g)) == U1(-1)
133-
@test_broken isdual(g[Block(1)])
133+
@test isdual(g[Block(1)])
134134
end
135135
end
136136

0 commit comments

Comments
 (0)