Skip to content

Commit f4aa481

Browse files
authored
[GradedAxes] Simplify dual graded unit range slicing (#1583)
1 parent 4ac4315 commit f4aa481

File tree

2 files changed

+44
-44
lines changed

2 files changed

+44
-44
lines changed

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -57,62 +57,41 @@ end
5757
function blockedunitrange_getindices(
5858
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
5959
)
60-
a_indices = getindex(nondual(a), indices)
61-
v = mortar(dual.(blocks(a_indices)))
62-
# flip v to stay consistent with other cases where axes(v) are used
63-
return flip_blockvector(v)
60+
# dual v axes to stay consistent with other cases where axes(v) are used
61+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
6462
end
6563

6664
function blockedunitrange_getindices(
6765
a::GradedUnitRangeDual,
6866
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
6967
)
70-
v = mortar(map(b -> a[b], blocks(indices)))
71-
# GradedOneTo appears in mortar
72-
# flip v axis to preserve dual information
68+
# dual v axis to preserve dual information
7369
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
74-
return flip_blockvector(v)
70+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
7571
end
7672

7773
function blockedunitrange_getindices(
7874
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
7975
)
80-
# Without converting `indices` to `Vector`,
81-
# mapping `indices` outputs a `BlockVector`
82-
# which is harder to reason about.
83-
vblocks = map(index -> a[index], Vector(indices))
84-
# We pass `length.(blocks)` to `mortar` in order
85-
# to pass block labels to the axes of the output,
86-
# if they exist. This makes it so that
87-
# `only(axes(a[indices])) isa `GradedUnitRange`
88-
# if `a isa `GradedUnitRange`, for example.
89-
90-
v = mortar(vblocks, length.(vblocks))
91-
# GradedOneTo appears in mortar
92-
# flip v axis to preserve dual information
76+
# dual v axis to preserve dual information
9377
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)]))
94-
return flip_blockvector(v)
78+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
9579
end
9680

9781
# Fixes ambiguity error.
98-
# TODO: Write this in terms of `blockedunitrange_getindices(dual(a), indices)`.
9982
function blockedunitrange_getindices(
10083
a::GradedUnitRangeDual, indices::AbstractBlockVector{<:Block{1}}
10184
)
102-
blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))
103-
# We pass `length.(blks)` to `mortar` in order
104-
# to pass block labels to the axes of the output,
105-
# if they exist. This makes it so that
106-
# `only(axes(a[indices])) isa `GradedUnitRange`
107-
# if `a isa `GradedUnitRange`, for example.
108-
v = mortar(blks, labelled_length.(blks))
109-
return flip_blockvector(v)
110-
end
111-
112-
function flip_blockvector(v::BlockVector)
113-
block_axes = flip.(axes(v))
114-
flipped = mortar(vec.(blocks(v)), block_axes)
115-
return flipped
85+
v = blockedunitrange_getindices(nondual(a), indices)
86+
# v elements are not dualled by dual_axes due to different structure.
87+
# take element dual here.
88+
return dual_axes(dual.(v))
89+
end
90+
91+
function dual_axes(v::BlockVector)
92+
# dual both v elements and v axes
93+
block_axes = dual.(axes(v))
94+
return mortar(dual.(blocks(v)), block_axes)
11695
end
11796

11897
Base.axes(a::GradedUnitRangeDual) = axes(nondual(a))

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,35 @@ end
219219
@test label(ad[Block(2)]) == U1(-1)
220220
@test label(ad[Block(2)[1:1]]) == U1(-1)
221221

222-
I = mortar([Block(2)[1:1]])
223-
g = ad[I]
224-
@test length(g) == 1
225-
@test label(first(g)) == U1(-1)
226-
@test isdual(g[Block(1)])
222+
v = ad[[Block(2)[1:1]]]
223+
@test v isa AbstractVector{LabelledInteger{Int64,U1}}
224+
@test length(v) == 1
225+
@test label(first(v)) == U1(-1)
226+
@test unlabel(first(v)) == 3
227+
@test isdual(v[Block(1)])
228+
@test isdual(axes(v, 1))
229+
@test blocklabels(axes(v, 1)) == [U1(-1)]
227230

228-
@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
229-
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
231+
v = ad[mortar([Block(2)[1:1]])]
232+
@test v isa AbstractVector{LabelledInteger{Int64,U1}}
233+
@test isdual(axes(v, 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
234+
@test label(first(v)) == U1(-1)
235+
@test unlabel(first(v)) == 3
236+
@test blocklabels(axes(v, 1)) == [U1(-1)]
237+
238+
v = ad[[Block(2)]]
239+
@test v isa AbstractVector{LabelledInteger{Int64,U1}}
240+
@test isdual(axes(v, 1)) # used in view(::BlockSparseVector, [Block(1)])
241+
@test label(first(v)) == U1(-1)
242+
@test unlabel(first(v)) == 3
243+
@test blocklabels(axes(v, 1)) == [U1(-1)]
244+
245+
v = ad[mortar([[Block(2)], [Block(1)]])]
246+
@test v isa AbstractVector{LabelledInteger{Int64,U1}}
247+
@test isdual(axes(v, 1))
248+
@test label(first(v)) == U1(-1)
249+
@test unlabel(first(v)) == 3
250+
@test blocklabels(axes(v, 1)) == [U1(-1), U1(0)]
230251
end
231252
end
232253

0 commit comments

Comments
 (0)