Skip to content

Commit 6555e0e

Browse files
authored
[GradedAxes] Introduce LabelledUnitRangeDual (#1571)
1 parent 3838f62 commit 6555e0e

File tree

8 files changed

+216
-24
lines changed

8 files changed

+216
-24
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@eval module $(gensym())
22
using Compat: Returns
3-
using Test: @test, @testset, @test_broken
3+
using Test: @test, @testset
44
using BlockArrays:
55
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
66
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
@@ -217,10 +217,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
217217
@test size(a[I, I]) == (1, 1)
218218
@test isdual(axes(a[I, :], 2))
219219
@test isdual(axes(a[:, I], 1))
220-
@test_broken isdual(axes(a[I, :], 1))
221-
@test_broken isdual(axes(a[:, I], 2))
222-
@test_broken isdual(axes(a[I, I], 1))
223-
@test_broken isdual(axes(a[I, I], 2))
220+
@test isdual(axes(a[I, :], 1))
221+
@test isdual(axes(a[:, I], 2))
222+
@test isdual(axes(a[I, I], 1))
223+
@test isdual(axes(a[I, I], 2))
224224
end
225225

226226
@testset "dual GradedUnitRange" begin
@@ -243,10 +243,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
243243
@test size(a[I, I]) == (1, 1)
244244
@test isdual(axes(a[I, :], 2))
245245
@test isdual(axes(a[:, I], 1))
246-
@test_broken isdual(axes(a[I, :], 1))
247-
@test_broken isdual(axes(a[:, I], 2))
248-
@test_broken isdual(axes(a[I, I], 1))
249-
@test_broken isdual(axes(a[I, I], 2))
246+
@test isdual(axes(a[I, :], 1))
247+
@test isdual(axes(a[:, I], 2))
248+
@test isdual(axes(a[I, I], 1))
249+
@test isdual(axes(a[I, I], 2))
250250
end
251251

252252
@testset "dual BlockedUnitRange" begin # self dual

NDTensors/src/lib/GradedAxes/src/GradedAxes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GradedAxes
22
include("blockedunitrange.jl")
33
include("gradedunitrange.jl")
44
include("dual.jl")
5+
include("labelledunitrangedual.jl")
56
include("gradedunitrangedual.jl")
67
include("onetoone.jl")
78
include("fusion.jl")

NDTensors/src/lib/GradedAxes/src/dual.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# default behavior: self-dual
2-
dual(r::AbstractUnitRange) = r
1+
# default behavior: any object is self-dual
2+
dual(x) = x
33
nondual(r::AbstractUnitRange) = r
44
isdual(::AbstractUnitRange) = false
55

@@ -11,4 +11,5 @@ label_dual(x) = label_dual(LabelledStyle(x), x)
1111
label_dual(::NotLabelled, x) = x
1212
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))
1313

14+
flip(a::AbstractUnitRange) = dual(label_dual(a))
1415
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ end
6868
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
6969
struct NoLabel end
7070
blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r))
71+
blocklabels(la::LabelledUnitRange) = [label(la)]
7172

7273
function LabelledNumbers.labelled_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
7374
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
3131
end
3232

3333
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
34-
return label_dual(getindex(nondual(a), indices))
34+
return dual(getindex(nondual(a), indices))
3535
end
3636

3737
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange)
38-
return label_dual(getindex(nondual(a), indices))
38+
return dual(getindex(nondual(a), indices))
39+
end
40+
41+
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockIndexRange)
42+
return dual(nondual(a)[indices])
3943
end
4044

4145
# fix ambiguity
@@ -49,20 +53,51 @@ function BlockArrays.blocklengths(a::GradedUnitRangeDual)
4953
return dual.(blocklengths(nondual(a)))
5054
end
5155

52-
function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices)
56+
# TODO: Move this to a `BlockArraysExtensions` library.
57+
function blockedunitrange_getindices(
58+
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
59+
)
5360
a_indices = getindex(nondual(a), indices)
54-
return mortar([label_dual(b) for b in blocks(a_indices)])
61+
v = mortar(dual.(blocks(a_indices)))
62+
# flip v to stay consistent with other cases where axes(v) are used
63+
return flip_blockvector(v)
5564
end
5665

57-
# TODO: Move this to a `BlockArraysExtensions` library.
58-
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
59-
return gradedunitrangedual_getindices_blocks(a, indices)
66+
function blockedunitrange_getindices(
67+
a::GradedUnitRangeDual,
68+
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
69+
)
70+
v = mortar(map(b -> a[b], blocks(indices)))
71+
# GradedOneTo appears in mortar
72+
# flip v axis to preserve dual information
73+
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
74+
return flip_blockvector(v)
6075
end
6176

6277
function blockedunitrange_getindices(
63-
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
78+
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
6479
)
65-
return gradedunitrangedual_getindices_blocks(a, indices)
80+
# Without converting `indices` to `Vector`,
81+
# mapping `indices` outputs a `BlockVector`
82+
# which is harder to reason about.
83+
vblocks = map(index -> a[index], Vector(indices))
84+
# We pass `length.(blocks)` to `mortar` in order
85+
# to pass block labels to the axes of the output,
86+
# if they exist. This makes it so that
87+
# `only(axes(a[indices])) isa `GradedUnitRange`
88+
# if `a isa `GradedUnitRange`, for example.
89+
90+
v = mortar(vblocks, length.(vblocks))
91+
# GradedOneTo appears in mortar
92+
# flip v axis to preserve dual information
93+
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)]))
94+
return flip_blockvector(v)
95+
end
96+
97+
function flip_blockvector(v::BlockVector)
98+
block_axes = flip.(axes(v))
99+
flipped = mortar(vec.(blocks(v)), block_axes)
100+
return flipped
66101
end
67102

68103
Base.axes(a::GradedUnitRangeDual) = axes(nondual(a))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# LabelledUnitRangeDual is obtained by slicing a GradedUnitRangeDual with a block
2+
3+
using ..LabelledNumbers: LabelledNumbers, label, labelled, unlabel
4+
5+
struct LabelledUnitRangeDual{T,NondualUnitRange<:AbstractUnitRange{T}} <:
6+
AbstractUnitRange{T}
7+
nondual_unitrange::NondualUnitRange
8+
end
9+
10+
dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
11+
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
12+
dual(a::LabelledUnitRangeDual) = nondual(a)
13+
label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a)))
14+
isdual(::LabelledUnitRangeDual) = true
15+
blocklabels(la::LabelledUnitRangeDual) = [label(la)]
16+
17+
LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a)))
18+
LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a))
19+
LabelledNumbers.LabelledStyle(::LabelledUnitRangeDual) = IsLabelled()
20+
21+
for f in [:first, :getindex, :last, :length, :step]
22+
@eval Base.$f(a::LabelledUnitRangeDual, args...) =
23+
labelled($f(unlabel(a), args...), label(a))
24+
end
25+
26+
# fix ambiguities
27+
Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i])
28+
function Base.getindex(a::LabelledUnitRangeDual, indices::AbstractUnitRange{<:Integer})
29+
return dual(nondual(a)[indices])
30+
end
31+
32+
function Base.iterate(a::LabelledUnitRangeDual, i)
33+
i == last(a) && return nothing
34+
next = convert(eltype(a), labelled(i + step(a), label(a)))
35+
return (next, next)
36+
end
37+
38+
function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual)
39+
println(io, typeof(a))
40+
return print(io, label(a), " => ", unlabel(a))
41+
end
42+
43+
function Base.show(io::IO, a::LabelledUnitRangeDual)
44+
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
45+
end
46+
47+
function Base.AbstractUnitRange{T}(a::LabelledUnitRangeDual) where {T}
48+
return AbstractUnitRange{T}(nondual(a))
49+
end

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using NDTensors.GradedAxes:
1717
AbstractGradedUnitRange,
1818
GradedAxes,
1919
GradedUnitRangeDual,
20+
LabelledUnitRangeDual,
2021
OneToOne,
2122
blocklabels,
2223
blockmergesortperm,
@@ -27,7 +28,8 @@ using NDTensors.GradedAxes:
2728
gradedrange,
2829
isdual,
2930
nondual
30-
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
31+
using NDTensors.LabelledNumbers:
32+
LabelledInteger, LabelledUnitRange, label, label_type, labelled, labelled_isequal, unlabel
3133
using Test: @test, @test_broken, @testset
3234
struct U1
3335
n::Int
@@ -58,6 +60,92 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n
5860
@test blockisequal(ad, a)
5961
end
6062

63+
@testset "LabelledUnitRangeDual" begin
64+
la = labelled(1:2, U1(1))
65+
@test la isa LabelledUnitRange
66+
@test label(la) == U1(1)
67+
@test blocklabels(la) == [U1(1)]
68+
@test unlabel(la) == 1:2
69+
@test la == 1:2
70+
@test !isdual(la)
71+
@test labelled_isequal(la, la)
72+
@test space_isequal(la, la)
73+
@test label_type(la) == U1
74+
75+
@test iterate(la) == (1, 1)
76+
@test iterate(la) == (1, 1)
77+
@test iterate(la, 1) == (2, 2)
78+
@test isnothing(iterate(la, 2))
79+
80+
lad = dual(la)
81+
@test lad isa LabelledUnitRangeDual
82+
@test label(lad) == U1(-1)
83+
@test blocklabels(lad) == [U1(-1)]
84+
@test unlabel(lad) == 1:2
85+
@test lad == 1:2
86+
@test labelled_isequal(lad, lad)
87+
@test space_isequal(lad, lad)
88+
@test !labelled_isequal(la, lad)
89+
@test !space_isequal(la, lad)
90+
@test isdual(lad)
91+
@test nondual(lad) === la
92+
@test dual(lad) === la
93+
@test label_type(lad) == U1
94+
95+
@test iterate(lad) == (1, 1)
96+
@test iterate(lad) == (1, 1)
97+
@test iterate(lad, 1) == (2, 2)
98+
@test isnothing(iterate(lad, 2))
99+
100+
lad2 = lad[1:1]
101+
@test lad2 isa LabelledUnitRangeDual
102+
@test label(lad2) == U1(-1)
103+
@test unlabel(lad2) == 1:1
104+
105+
laf = flip(la)
106+
@test laf isa LabelledUnitRangeDual
107+
@test label(laf) == U1(1)
108+
@test unlabel(laf) == 1:2
109+
@test labelled_isequal(la, laf)
110+
@test !space_isequal(la, laf)
111+
112+
ladf = flip(dual(la))
113+
@test ladf isa LabelledUnitRange
114+
@test label(ladf) == U1(-1)
115+
@test unlabel(ladf) == 1:2
116+
117+
lafd = dual(flip(la))
118+
@test lafd isa LabelledUnitRange
119+
@test label(lafd) == U1(-1)
120+
@test unlabel(lafd) == 1:2
121+
122+
# check default behavior for objects without dual
123+
la = labelled(1:2, 'x')
124+
lad = dual(la)
125+
@test lad isa LabelledUnitRangeDual
126+
@test label(lad) == 'x'
127+
@test blocklabels(lad) == ['x']
128+
@test unlabel(lad) == 1:2
129+
@test lad == 1:2
130+
@test labelled_isequal(lad, lad)
131+
@test space_isequal(lad, lad)
132+
@test labelled_isequal(la, lad)
133+
@test !space_isequal(la, lad)
134+
@test isdual(lad)
135+
@test nondual(lad) === la
136+
@test dual(lad) === la
137+
138+
laf = flip(la)
139+
@test laf isa LabelledUnitRangeDual
140+
@test label(laf) == 'x'
141+
@test unlabel(laf) == 1:2
142+
143+
ladf = flip(lad)
144+
@test ladf isa LabelledUnitRange
145+
@test label(ladf) == 'x'
146+
@test unlabel(ladf) == 1:2
147+
end
148+
61149
@testset "GradedUnitRangeDual" begin
62150
for a in
63151
[gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]]
@@ -124,13 +212,21 @@ end
124212
@test blockmergesortperm(a) == [Block(1), Block(2)]
125213
@test blockmergesortperm(ad) == [Block(1), Block(2)]
126214

127-
@test_broken isdual(ad[Block(1)])
128-
@test_broken isdual(ad[Block(1)[1:1]])
215+
@test isdual(ad[Block(1)])
216+
@test isdual(ad[Block(1)[1:1]])
217+
@test ad[Block(1)] isa LabelledUnitRangeDual
218+
@test ad[Block(1)[1:1]] isa LabelledUnitRangeDual
219+
@test label(ad[Block(2)]) == U1(-1)
220+
@test label(ad[Block(2)[1:1]]) == U1(-1)
221+
129222
I = mortar([Block(2)[1:1]])
130223
g = ad[I]
131224
@test length(g) == 1
132225
@test label(first(g)) == U1(-1)
133-
@test_broken isdual(g[Block(1)])
226+
@test isdual(g[Block(1)])
227+
228+
@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
229+
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
134230
end
135231
end
136232

NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,12 @@ function Base.iterate(a::LabelledUnitRange, i)
5252
next = convert(eltype(a), labelled(i + step(a), label(a)))
5353
return (next, next)
5454
end
55+
56+
function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRange)
57+
println(io, typeof(a))
58+
return print(io, label(a), " => ", unlabel(a))
59+
end
60+
61+
function Base.show(io::IO, a::LabelledUnitRange)
62+
return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a))
63+
end

0 commit comments

Comments
 (0)