Skip to content

Commit c60c080

Browse files
committed
fix dual when slicing with Vector
1 parent 5d047a8 commit c60c080

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@eval module $(gensym())
22
using Compat: Returns
3-
using Test: @test, @testset, @test_broken
3+
using Test: @test, @testset
44
using BlockArrays:
55
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
66
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
@@ -217,10 +217,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
217217
@test size(a[I, I]) == (1, 1)
218218
@test isdual(axes(a[I, :], 2))
219219
@test isdual(axes(a[:, I], 1))
220-
@test_broken isdual(axes(a[I, :], 1))
221-
@test_broken isdual(axes(a[:, I], 2))
222-
@test_broken isdual(axes(a[I, I], 1))
223-
@test_broken isdual(axes(a[I, I], 2))
220+
@test isdual(axes(a[I, :], 1))
221+
@test isdual(axes(a[:, I], 2))
222+
@test isdual(axes(a[I, I], 1))
223+
@test isdual(axes(a[I, I], 2))
224224
end
225225

226226
@testset "dual GradedUnitRange" begin
@@ -243,10 +243,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
243243
@test size(a[I, I]) == (1, 1)
244244
@test isdual(axes(a[I, :], 2))
245245
@test isdual(axes(a[:, I], 1))
246-
@test_broken isdual(axes(a[I, :], 1))
247-
@test_broken isdual(axes(a[:, I], 2))
248-
@test_broken isdual(axes(a[I, I], 1))
249-
@test_broken isdual(axes(a[I, I], 2))
246+
@test isdual(axes(a[I, :], 1))
247+
@test isdual(axes(a[:, I], 2))
248+
@test isdual(axes(a[I, I], 1))
249+
@test isdual(axes(a[I, I], 2))
250250
end
251251

252252
@testset "dual BlockedUnitRange" begin # self dual

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,46 @@ function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices)
5959
end
6060

6161
# TODO: Move this to a `BlockArraysExtensions` library.
62-
function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}})
62+
function blockedunitrange_getindices(
63+
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
64+
)
6365
return gradedunitrangedual_getindices_blocks(a, indices)
6466
end
6567

6668
function blockedunitrange_getindices(
67-
a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}
69+
a::GradedUnitRangeDual,
70+
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
6871
)
69-
return gradedunitrangedual_getindices_blocks(a, indices)
72+
arr = mortar(map(b -> a[b], blocks(indices)))
73+
# GradedOneTo appears in mortar
74+
# flip arr axis to preserve dual information
75+
# axes(arr) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
76+
# TODO way to create BlockArray with specified axis without relying on internal?
77+
block_axes = (flip(only(axes(arr))),)
78+
flipped = BlockArrays._BlockArray(vec.(blocks(arr)), block_axes)
79+
return flipped
80+
end
81+
82+
function blockedunitrange_getindices(
83+
a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
84+
)
85+
# Without converting `indices` to `Vector`,
86+
# mapping `indices` outputs a `BlockVector`
87+
# which is harder to reason about.
88+
vblocks = map(index -> a[index], Vector(indices))
89+
# We pass `length.(blocks)` to `mortar` in order
90+
# to pass block labels to the axes of the output,
91+
# if they exist. This makes it so that
92+
# `only(axes(a[indices])) isa `GradedUnitRange`
93+
# if `a isa `GradedUnitRange`, for example.
94+
95+
arr = mortar(vblocks, length.(vblocks))
96+
# GradedOneTo appears in mortar
97+
# axes(arr) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]]))
98+
# TODO way to create BlockArray with specified axis without relying on internal?
99+
block_axes = (flip(only(axes(arr))),)
100+
flipped = BlockArrays._BlockArray(vec.(blocks(arr)), block_axes)
101+
return flipped
70102
end
71103

72104
Base.axes(a::GradedUnitRangeDual) = axes(nondual(a))

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ end
131131
@test length(g) == 1
132132
@test label(first(g)) == U1(-1)
133133
@test isdual(g[Block(1)])
134+
135+
@test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)])
136+
@test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]])
134137
end
135138
end
136139

0 commit comments

Comments
 (0)