Skip to content

Commit 75be561

Browse files
committed
use blockedunitrange_getindex(nondual(a)
1 parent 230ddae commit 75be561

File tree

2 files changed

+11
-33
lines changed

2 files changed

+11
-33
lines changed

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,61 +57,38 @@ end
5757
function blockedunitrange_getindices(
5858
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
5959
)
60-
a_indices = getindex(nondual(a), indices)
61-
v = mortar(dual.(blocks(a_indices)))
6260
# flip v to stay consistent with other cases where axes(v) are used
63-
return flip_axes(v)
61+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
6462
end
6563

6664
function blockedunitrange_getindices(
6765
a::GradedUnitRangeDual,
6866
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
6967
)
70-
v = mortar(map(b -> a[b], blocks(indices)))
71-
# GradedOneTo appears in mortar
7268
# flip v axis to preserve dual information
7369
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
74-
return flip_axes(v)
70+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
7571
end
7672

7773
function blockedunitrange_getindices(
7874
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
7975
)
80-
# Without converting `indices` to `Vector`,
81-
# mapping `indices` outputs a `BlockVector`
82-
# which is harder to reason about.
83-
vblocks = map(index -> a[index], Vector(indices))
84-
# We pass `length.(blocks)` to `mortar` in order
85-
# to pass block labels to the axes of the output,
86-
# if they exist. This makes it so that
87-
# `only(axes(a[indices])) isa `GradedUnitRange`
88-
# if `a isa `GradedUnitRange`, for example.
89-
90-
v = mortar(vblocks, length.(vblocks))
91-
# GradedOneTo appears in mortar
9276
# flip v axis to preserve dual information
9377
# axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)]))
94-
return flip_axes(v)
78+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
9579
end
9680

9781
# Fixes ambiguity error.
98-
# TODO: Write this in terms of `blockedunitrange_getindices(dual(a), indices)`.
9982
function blockedunitrange_getindices(
10083
a::GradedUnitRangeDual, indices::AbstractBlockVector{<:Block{1}}
10184
)
102-
blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))
103-
# We pass `length.(blks)` to `mortar` in order
104-
# to pass block labels to the axes of the output,
105-
# if they exist. This makes it so that
106-
# `only(axes(a[indices])) isa `GradedUnitRange`
107-
# if `a isa `GradedUnitRange`, for example.
108-
v = mortar(blks, labelled_length.(blks))
109-
return flip_axes(v)
110-
end
111-
112-
function flip_axes(v::BlockVector)
113-
block_axes = flip.(axes(v))
114-
flipped = mortar(vec.(blocks(v)), block_axes)
85+
return dual_axes(blockedunitrange_getindices(nondual(a), indices))
86+
end
87+
88+
function dual_axes(v::BlockVector)
89+
# dual both v elements and v axes
90+
block_axes = dual.(axes(v))
91+
flipped = mortar(dual.(vec.(blocks(v))), block_axes)
11592
return flipped
11693
end
11794

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ end
227227

228228
@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
229229
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
230+
@test isdual(axes(ad[mortar([[Block(1)], [Block(2)]])]))
230231
end
231232
end
232233

0 commit comments

Comments
 (0)