Skip to content

Commit c1ec06d

Browse files
committed
fix slicing with BlockVector
1 parent 38a21f6 commit c1ec06d

File tree

4 files changed

+68
-20
lines changed

4 files changed

+68
-20
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
112112
@test blocksize(m) == (3, 3)
113113
@test a == splitdims(m, (d1, d2), (d1, d2))
114114
end
115+
115116
@testset "dual axes" begin
116117
r = gradedrange([U1(0) => 2, U1(1) => 2])
117118
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
118119
a = BlockSparseArray{elt}(ax...)
119120
@views for b in [Block(1, 1), Block(2, 2)]
120121
a[b] = randn(elt, size(a[b]))
121122
end
122-
# TODO: Define and use `isdual` here.
123123
for dim in 1:ndims(a)
124124
@test typeof(ax[dim]) === typeof(axes(a, dim))
125125
@test isdual(ax[dim]) == isdual(axes(a, dim))
@@ -176,7 +176,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
176176
@test a[I, :] isa AbstractBlockArray
177177
@test a[:, I] isa AbstractBlockArray
178178
@test size(a[I, I]) == (1, 1)
179-
@test !GradedAxes.isdual(axes(a[I, I], 1))
179+
@test !isdual(axes(a[I, I], 1))
180180
end
181181

182182
@testset "GradedUnitRange" begin
@@ -190,14 +190,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
190190
@test Array(b) == 2 * Array(a)
191191
for i in 1:2
192192
@test axes(b, i) isa GradedUnitRange
193-
@test_broken axes(a[:, :], i) isa GradedUnitRange
193+
@test axes(a[:, :], i) isa GradedUnitRange
194194
end
195195

196196
I = [Block(1)[1:1]]
197-
@test_broken a[I, :] isa AbstractBlockArray
198-
@test_broken a[:, I] isa AbstractBlockArray
197+
@test a[I, :] isa AbstractBlockArray
198+
@test axes(a[I, :], 1) isa GradedOneTo
199+
@test axes(a[I, :], 2) isa GradedUnitRange
200+
201+
@test a[:, I] isa AbstractBlockArray
202+
@test axes(a[:, I], 2) isa GradedOneTo
203+
@test axes(a[:, I], 1) isa GradedUnitRange
199204
@test size(a[I, I]) == (1, 1)
200-
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
205+
@test !isdual(axes(a[I, I], 1))
201206
end
202207

203208
# Test case when all axes are dual.
@@ -212,13 +217,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
212217
@test Array(b) == 2 * Array(a)
213218
for i in 1:2
214219
@test axes(b, i) isa GradedUnitRangeDual
215-
@test_broken axes(a[:, :], i) isa GradedUnitRangeDual
220+
@test axes(a[:, :], i) isa GradedUnitRangeDual
216221
end
217222
I = [Block(1)[1:1]]
218-
@test_broken a[I, :] isa AbstractBlockArray
219-
@test_broken a[:, I] isa AbstractBlockArray
223+
@test a[I, :] isa AbstractBlockArray
224+
@test a[:, I] isa AbstractBlockArray
220225
@test size(a[I, I]) == (1, 1)
221-
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
226+
@test isdual(axes(a[I, :], 2))
227+
@test isdual(axes(a[:, I], 1))
228+
@test_broken isdual(axes(a[I, :], 1))
229+
@test_broken isdual(axes(a[:, I], 2))
230+
@test_broken isdual(axes(a[I, I], 1))
231+
@test_broken isdual(axes(a[I, I], 2))
222232
end
223233

224234
@testset "dual GradedUnitRange" begin
@@ -232,14 +242,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
232242
@test Array(b) == 2 * Array(a)
233243
for i in 1:2
234244
@test axes(b, i) isa GradedUnitRangeDual
235-
@test_broken axes(a[:, :], i) isa GradedUnitRangeDual
245+
@test axes(a[:, :], i) isa GradedUnitRangeDual
236246
end
237247

238248
I = [Block(1)[1:1]]
239-
@test_broken a[I, :] isa AbstractBlockArray
240-
@test_broken a[:, I] isa AbstractBlockArray
249+
@test a[I, :] isa AbstractBlockArray
250+
@test a[:, I] isa AbstractBlockArray
241251
@test size(a[I, I]) == (1, 1)
242-
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
252+
@test isdual(axes(a[I, :], 2))
253+
@test isdual(axes(a[:, I], 1))
254+
@test_broken isdual(axes(a[I, :], 1))
255+
@test_broken isdual(axes(a[:, I], 2))
256+
@test_broken isdual(axes(a[I, I], 1))
257+
@test_broken isdual(axes(a[I, I], 2))
243258
end
244259

245260
@testset "dual BlockedUnitRange" begin # self dual
@@ -261,7 +276,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
261276
@test a[I, :] isa BlockSparseArray
262277
@test a[:, I] isa BlockSparseArray
263278
@test size(a[I, I]) == (1, 1)
264-
@test !GradedAxes.isdual(axes(a[I, I], 1))
279+
@test !isdual(axes(a[I, I], 1))
265280
end
266281

267282
# Test case when all axes are dual from taking the adjoint.
@@ -281,8 +296,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
281296
end
282297

283298
I = [Block(1)[1:1]]
284-
@test_broken size(b[I, :]) == (1, 4)
285-
@test_broken size(b[:, I]) == (4, 1)
299+
@test size(b[I, :]) == (1, 4)
300+
@test size(b[:, I]) == (4, 1)
286301
@test size(b[I, I]) == (1, 1)
287302
end
288303
end

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,13 @@ function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::Vector
222222
return map(index -> a[index], indices)
223223
end
224224

225+
function blockedunitrange_getindices(
226+
a::AbstractGradedUnitRange,
227+
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
228+
)
229+
return mortar(map(b -> a[b], blocks(indices)))
230+
end
231+
225232
function blockedunitrange_getindices(a::AbstractGradedUnitRange, index)
226233
return labelled(unlabel_blocks(a)[index], get_label(a, index))
227234
end
@@ -345,6 +352,9 @@ function BlockArrays.combine_blockaxes(a::GradedUnitRange, b::GradedUnitRange)
345352
return GradedUnitRange(new_first, new_blocklasts)
346353
end
347354

355+
# preserve axes in SubArray
356+
Base.axes(S::Base.Slice{<:AbstractGradedUnitRange}) = (S.indices,)
357+
348358
# Version of length that checks that all blocks have the same label
349359
# and returns a labelled length with that label.
350360
function labelled_length(a::AbstractBlockVector{<:Integer})

NDTensors/src/lib/GradedAxes/test/test_basics.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using BlockArrays:
1010
blocklength,
1111
blocklengths,
1212
blocks,
13-
combine_blockaxes
13+
combine_blockaxes,
14+
mortar
1415
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange
1516
using NDTensors.LabelledNumbers:
1617
LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel
@@ -97,6 +98,7 @@ end
9798
@test blocklabels(only(axes(a))) == blocklabels(a)
9899

99100
@test combine_blockaxes(a, a) isa GradedOneTo
101+
@test axes(Base.Slice(a)) isa Tuple{typeof(a)}
100102
end
101103

102104
# Slicing operations
@@ -145,6 +147,7 @@ end
145147
@test length(ax) == length(a)
146148
@test blocklengths(ax) == blocklengths(a)
147149
@test blocklabels(ax) == blocklabels(a)
150+
@test axes(Base.Slice(a)) isa Tuple{typeof(a)}
148151

149152
x = gradedrange(["x" => 2, "y" => 3])
150153
a = x[2:4][1:2]
@@ -227,5 +230,17 @@ end
227230
# once `blocklengths(::BlockVector)` is defined.
228231
@test blocklengths(ax) == [2, 2]
229232
@test blocklabels(ax) == blocklabels(a)
233+
234+
x = gradedrange(["x" => 2, "y" => 3])
235+
I = mortar([Block(1)[1:1]])
236+
a = x[I]
237+
@test length(a) == 1
238+
@test label(first(a)) == "x"
239+
240+
x = gradedrange(["x" => 2, "y" => 3])[1:5]
241+
I = mortar([Block(1)[1:1]])
242+
a = x[I]
243+
@test length(a) == 1
244+
@test label(first(a)) == "x"
230245
end
231246
end

NDTensors/src/lib/GradedAxes/test/test_dual.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using BlockArrays:
1010
blocklength,
1111
blocklengths,
1212
blocks,
13-
findblock
13+
findblock,
14+
mortar
1415
using NDTensors.GradedAxes:
1516
AbstractGradedUnitRange,
1617
GradedAxes,
@@ -26,7 +27,7 @@ using NDTensors.GradedAxes:
2627
isdual,
2728
nondual
2829
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal
29-
using Test: @test, @testset
30+
using Test: @test, @test_broken, @testset
3031
struct U1
3132
n::Int
3233
end
@@ -75,6 +76,7 @@ end
7576

7677
@test isdual(ad)
7778
@test !isdual(a)
79+
@test axes(Base.Slice(a)) isa Tuple{typeof(a)}
7880

7981
@test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))]
8082
@test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))]
@@ -106,6 +108,12 @@ end
106108
@test blocklength(blockmergesortperm(ad)) == 2
107109
@test blockmergesortperm(a) == [Block(1), Block(2)]
108110
@test blockmergesortperm(ad) == [Block(1), Block(2)]
111+
112+
I = mortar([Block(2)[1:1]])
113+
g = ad[I]
114+
@test length(g) == 1
115+
@test label(first(g)) == U1(-1)
116+
@test_broken isdual(g[Block(1)])
109117
end
110118
end
111119

0 commit comments

Comments
 (0)