Skip to content

Commit 8d4a14f

Browse files
authored
[GradedAxes] Introduce GradedUnitRangeDual (#1531)
1 parent 360f95e commit 8d4a14f

File tree

3 files changed

+170
-21
lines changed

3 files changed

+170
-21
lines changed

ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
@eval module $(gensym())
22
using Compat: Returns
33
using Test: @test, @testset, @test_broken
4-
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
4+
using BlockArrays:
5+
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
56
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
67
using NDTensors.GradedAxes:
7-
GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange
8+
GradedAxes,
9+
GradedOneTo,
10+
GradedUnitRange,
11+
GradedUnitRangeDual,
12+
blocklabels,
13+
dual,
14+
gradedrange,
15+
isdual
816
using NDTensors.LabelledNumbers: label
917
using NDTensors.SparseArrayInterface: nstored
1018
using NDTensors.TensorAlgebra: fusedims, splitdims
19+
using LinearAlgebra: adjoint
1120
using Random: randn!
1221
function blockdiagonal!(f, a::AbstractArray)
1322
for i in 1:minimum(blocksize(a))
@@ -31,15 +40,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3140
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
3241
a = BlockSparseArray{elt}(d1, d2, d1, d2)
3342
blockdiagonal!(randn!, a)
43+
@test axes(a, 1) isa GradedOneTo
44+
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo
3445

3546
for b in (a + a, 2 * a)
3647
@test size(b) == (4, 4, 4, 4)
3748
@test blocksize(b) == (2, 2, 2, 2)
3849
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
3950
@test nstored(b) == 32
4051
@test block_nstored(b) == 2
41-
# TODO: Have to investigate why this fails
42-
# on Julia v1.6, or drop support for v1.6.
4352
for i in 1:ndims(a)
4453
@test axes(b, i) isa GradedOneTo
4554
end
@@ -103,16 +112,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
103112
@test blocksize(m) == (3, 3)
104113
@test a == splitdims(m, (d1, d2), (d1, d2))
105114
end
115+
106116
@testset "dual axes" begin
107117
r = gradedrange([U1(0) => 2, U1(1) => 2])
108118
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
109119
a = BlockSparseArray{elt}(ax...)
110120
@views for b in [Block(1, 1), Block(2, 2)]
111121
a[b] = randn(elt, size(a[b]))
112122
end
113-
# TODO: Define and use `isdual` here.
114123
for dim in 1:ndims(a)
115124
@test typeof(ax[dim]) === typeof(axes(a, dim))
125+
@test isdual(ax[dim]) == isdual(axes(a, dim))
116126
end
117127
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
118128
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
@@ -130,41 +140,149 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
130140
@test a[I] == a_dense[I]
131141
end
132142
@test axes(a') == dual.(reverse(axes(a)))
133-
# TODO: Define and use `isdual` here.
134-
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
135-
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))
143+
144+
@test isdual(axes(a', 1)) isdual(axes(a, 2))
145+
@test isdual(axes(a', 2)) isdual(axes(a, 1))
136146
@test isnothing(show(devnull, MIME("text/plain"), a))
137147

138148
# Check preserving dual in tensor algebra.
139149
for b in (a + a, 2 * a, 3 * a - a)
140150
@test Array(b) 2 * Array(a)
141-
# TODO: Define and use `isdual` here.
142151
for dim in 1:ndims(a)
143-
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
152+
@test isdual(axes(b, dim)) == isdual(axes(a, dim))
144153
end
145154
end
146155

147156
@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
148157
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
149158
end
150159

160+
@testset "GradedOneTo" begin
161+
r = gradedrange([U1(0) => 2, U1(1) => 2])
162+
a = BlockSparseArray{elt}(r, r)
163+
@views for i in [Block(1, 1), Block(2, 2)]
164+
a[i] = randn(elt, size(a[i]))
165+
end
166+
b = 2 * a
167+
@test block_nstored(b) == 2
168+
@test Array(b) == 2 * Array(a)
169+
for i in 1:2
170+
@test axes(b, i) isa GradedOneTo
171+
@test axes(a[:, :], i) isa GradedOneTo
172+
end
173+
174+
I = [Block(1)[1:1]]
175+
@test a[I, :] isa AbstractBlockArray
176+
@test a[:, I] isa AbstractBlockArray
177+
@test size(a[I, I]) == (1, 1)
178+
@test !isdual(axes(a[I, I], 1))
179+
end
180+
181+
@testset "GradedUnitRange" begin
182+
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
183+
a = BlockSparseArray{elt}(r, r)
184+
@views for i in [Block(1, 1), Block(2, 2)]
185+
a[i] = randn(elt, size(a[i]))
186+
end
187+
b = 2 * a
188+
@test block_nstored(b) == 2
189+
@test Array(b) == 2 * Array(a)
190+
for i in 1:2
191+
@test axes(b, i) isa GradedUnitRange
192+
@test axes(a[:, :], i) isa GradedUnitRange
193+
end
194+
195+
I = [Block(1)[1:1]]
196+
@test a[I, :] isa AbstractBlockArray
197+
@test axes(a[I, :], 1) isa GradedOneTo
198+
@test axes(a[I, :], 2) isa GradedUnitRange
199+
200+
@test a[:, I] isa AbstractBlockArray
201+
@test axes(a[:, I], 2) isa GradedOneTo
202+
@test axes(a[:, I], 1) isa GradedUnitRange
203+
@test size(a[I, I]) == (1, 1)
204+
@test !isdual(axes(a[I, I], 1))
205+
end
206+
151207
# Test case when all axes are dual.
152-
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
208+
@testset "dual GradedOneTo" begin
209+
r = gradedrange([U1(-1) => 2, U1(1) => 2])
153210
a = BlockSparseArray{elt}(dual(r), dual(r))
154211
@views for i in [Block(1, 1), Block(2, 2)]
155212
a[i] = randn(elt, size(a[i]))
156213
end
157214
b = 2 * a
158215
@test block_nstored(b) == 2
159216
@test Array(b) == 2 * Array(a)
160-
for ax in axes(b)
161-
@test ax isa UnitRangeDual
217+
for i in 1:2
218+
@test axes(b, i) isa GradedUnitRangeDual
219+
@test axes(a[:, :], i) isa GradedUnitRangeDual
162220
end
221+
I = [Block(1)[1:1]]
222+
@test a[I, :] isa AbstractBlockArray
223+
@test a[:, I] isa AbstractBlockArray
224+
@test size(a[I, I]) == (1, 1)
225+
@test isdual(axes(a[I, :], 2))
226+
@test isdual(axes(a[:, I], 1))
227+
@test_broken isdual(axes(a[I, :], 1))
228+
@test_broken isdual(axes(a[:, I], 2))
229+
@test_broken isdual(axes(a[I, I], 1))
230+
@test_broken isdual(axes(a[I, I], 2))
231+
end
232+
233+
@testset "dual GradedUnitRange" begin
234+
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
235+
a = BlockSparseArray{elt}(dual(r), dual(r))
236+
@views for i in [Block(1, 1), Block(2, 2)]
237+
a[i] = randn(elt, size(a[i]))
238+
end
239+
b = 2 * a
240+
@test block_nstored(b) == 2
241+
@test Array(b) == 2 * Array(a)
242+
for i in 1:2
243+
@test axes(b, i) isa GradedUnitRangeDual
244+
@test axes(a[:, :], i) isa GradedUnitRangeDual
245+
end
246+
247+
I = [Block(1)[1:1]]
248+
@test a[I, :] isa AbstractBlockArray
249+
@test a[:, I] isa AbstractBlockArray
250+
@test size(a[I, I]) == (1, 1)
251+
@test isdual(axes(a[I, :], 2))
252+
@test isdual(axes(a[:, I], 1))
253+
@test_broken isdual(axes(a[I, :], 1))
254+
@test_broken isdual(axes(a[:, I], 2))
255+
@test_broken isdual(axes(a[I, I], 1))
256+
@test_broken isdual(axes(a[I, I], 2))
163257
end
164258

165-
# Test case when all axes are dual
166-
# from taking the adjoint.
167-
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
259+
@testset "dual BlockedUnitRange" begin # self dual
260+
r = blockedrange([2, 2])
261+
a = BlockSparseArray{elt}(dual(r), dual(r))
262+
@views for i in [Block(1, 1), Block(2, 2)]
263+
a[i] = randn(elt, size(a[i]))
264+
end
265+
b = 2 * a
266+
@test block_nstored(b) == 2
267+
@test Array(b) == 2 * Array(a)
268+
@test a[:, :] isa BlockSparseArray
269+
for i in 1:2
270+
@test axes(b, i) isa BlockedOneTo
271+
@test axes(a[:, :], i) isa BlockedOneTo
272+
end
273+
274+
I = [Block(1)[1:1]]
275+
@test a[I, :] isa BlockSparseArray
276+
@test a[:, I] isa BlockSparseArray
277+
@test size(a[I, I]) == (1, 1)
278+
@test !isdual(axes(a[I, I], 1))
279+
end
280+
281+
# Test case when all axes are dual from taking the adjoint.
282+
for r in (
283+
gradedrange([U1(0) => 2, U1(1) => 2]),
284+
gradedrange([U1(0) => 2, U1(1) => 2])[begin:end],
285+
)
168286
a = BlockSparseArray{elt}(r, r)
169287
@views for i in [Block(1, 1), Block(2, 2)]
170288
a[i] = randn(elt, size(a[i]))
@@ -173,8 +291,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
173291
@test block_nstored(b) == 2
174292
@test Array(b) == 2 * Array(a)'
175293
for ax in axes(b)
176-
@test ax isa UnitRangeDual
294+
@test ax isa typeof(dual(r))
177295
end
296+
297+
I = [Block(1)[1:1]]
298+
@test size(b[I, :]) == (1, 4)
299+
@test size(b[:, I]) == (4, 1)
300+
@test size(b[I, I]) == (1, 1)
178301
end
179302
end
180303
@testset "Matrix multiplication" begin

src/abstractblocksparsearray/views.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
using BlockArrays:
2-
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock
2+
AbstractBlockedUnitRange,
3+
BlockArrays,
4+
Block,
5+
BlockIndexRange,
6+
BlockedVector,
7+
blocklength,
8+
blocksize,
9+
viewblock
310

411
# This splits `BlockIndexRange{N}` into
512
# `NTuple{N,BlockIndexRange{1}}`.
@@ -191,7 +198,9 @@ function to_blockindexrange(
191198
# work right now.
192199
return blocks(a.blocks)[Int(I)]
193200
end
194-
function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1})
201+
function to_blockindexrange(
202+
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
203+
)
195204
@assert I in only(blockaxes(a.indices))
196205
return I
197206
end

test/test_basics.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@ using BlockArrays:
1515
blocksizes,
1616
mortar
1717
using Compat: @compat
18-
using LinearAlgebra: mul!
18+
using LinearAlgebra: Adjoint, mul!
1919
using NDTensors.BlockSparseArrays:
20-
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
20+
@view!,
21+
BlockSparseArray,
22+
BlockView,
23+
block_nstored,
24+
block_reshape,
25+
block_stored_indices,
26+
view!
2127
using NDTensors.SparseArrayInterface: nstored
2228
using NDTensors.TensorAlgebra: contract
2329
using Test: @test, @test_broken, @test_throws, @testset
@@ -44,6 +50,17 @@ include("TestBlockSparseArraysUtils.jl")
4450
a[Block(2, 2)] = randn(elt, 3, 3)
4551
@test a[2:4, 4] == Array(a)[2:4, 4]
4652
@test_broken a[4, 2:4]
53+
54+
@test a[Block(1), :] isa BlockSparseArray{elt}
55+
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
56+
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
57+
# could also be directly a BlockSparseArray
58+
59+
a = BlockSparseArray{elt}([1], [1, 1])
60+
a[1, 2] = 1
61+
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
62+
ah = adjoint(a)
63+
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
4764
end
4865
@testset "Basics" begin
4966
a = BlockSparseArray{elt}([2, 3], [2, 3])

0 commit comments

Comments
 (0)