Skip to content

Commit aa1b655

Browse files
committed
more tests
1 parent 8a353bd commit aa1b655

File tree

2 files changed

+77
-17
lines changed
  • NDTensors/src/lib/BlockSparseArrays

2 files changed

+77
-17
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
@eval module $(gensym())
22
using Compat: Returns
33
using Test: @test, @testset, @test_broken
4-
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
4+
using BlockArrays:
5+
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
56
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
67
using NDTensors.GradedAxes:
7-
GradedAxes, GradedOneTo, GradedUnitRangeDual, blocklabels, dual, gradedrange
8+
GradedAxes,
9+
GradedOneTo,
10+
GradedUnitRange,
11+
GradedUnitRangeDual,
12+
blocklabels,
13+
dual,
14+
gradedrange
815
using NDTensors.LabelledNumbers: label
916
using NDTensors.SparseArrayInterface: nstored
1017
using NDTensors.TensorAlgebra: fusedims, splitdims
@@ -147,8 +154,50 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
147154
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
148155
end
149156

157+
@testset "GradedOneTo" begin
158+
r = gradedrange([U1(0) => 2, U1(1) => 2])
159+
a = BlockSparseArray{elt}(r, r)
160+
@views for i in [Block(1, 1), Block(2, 2)]
161+
a[i] = randn(elt, size(a[i]))
162+
end
163+
b = 2 * a
164+
@test block_nstored(b) == 2
165+
@test Array(b) == 2 * Array(a)
166+
for i in 1:2
167+
@test axes(b, i) isa GradedOneTo
168+
@test axes(a[:, :], i) isa GradedOneTo
169+
end
170+
171+
I = [Block(1)[1:1]]
172+
@test a[I, :] isa AbstractBlockArray
173+
@test a[:, I] isa AbstractBlockArray
174+
@test size(a[I, I]) == (1, 1)
175+
@test !GradedAxes.isdual(axes(a[I, I], 1))
176+
end
177+
178+
@testset "GradedUnitRange" begin
179+
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
180+
a = BlockSparseArray{elt}(r, r)
181+
@views for i in [Block(1, 1), Block(2, 2)]
182+
a[i] = randn(elt, size(a[i]))
183+
end
184+
b = 2 * a
185+
@test block_nstored(b) == 2
186+
@test Array(b) == 2 * Array(a)
187+
for i in 1:2
188+
@test axes(b, i) isa GradedUnitRange
189+
@test_broken axes(a[:, :], i) isa GradedUnitRange
190+
end
191+
192+
I = [Block(1)[1:1]]
193+
@test_broken a[I, :] isa AbstractBlockArray
194+
@test_broken a[:, I] isa AbstractBlockArray
195+
@test size(a[I, I]) == (1, 1)
196+
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
197+
end
198+
150199
# Test case when all axes are dual.
151-
@testset "BlockedOneTo" begin
200+
@testset "dual BlockedOneTo" begin
152201
r = gradedrange([U1(0) => 2, U1(1) => 2])
153202
a = BlockSparseArray{elt}(dual(r), dual(r))
154203
@views for i in [Block(1, 1), Block(2, 2)]
@@ -162,13 +211,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
162211
@test_broken axes(a[:, :], i) isa GradedUnitRangeDual
163212
end
164213
I = [Block(1)[1:1]]
165-
@test_broken a[I, :]
166-
@test_broken a[:, I]
214+
@test_broken a[I, :] isa AbstractBlockArray
215+
@test_broken a[:, I] isa AbstractBlockArray
167216
@test size(a[I, I]) == (1, 1)
168217
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
169218
end
170219

171-
@testset "GradedUnitRange" begin
220+
@testset "dual GradedUnitRange" begin
172221
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
173222
a = BlockSparseArray{elt}(dual(r), dual(r))
174223
@views for i in [Block(1, 1), Block(2, 2)]
@@ -183,13 +232,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
183232
end
184233

185234
I = [Block(1)[1:1]]
186-
@test_broken a[I, :]
187-
@test_broken a[:, I]
235+
@test_broken a[I, :] isa AbstractBlockArray
236+
@test_broken a[:, I] isa AbstractBlockArray
188237
@test size(a[I, I]) == (1, 1)
189238
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
190239
end
191240

192-
@testset "BlockedUnitRange" begin # self dual
241+
@testset "dual BlockedUnitRange" begin # self dual
193242
r = blockedrange([2, 2])
194243
a = BlockSparseArray{elt}(dual(r), dual(r))
195244
@views for i in [Block(1, 1), Block(2, 2)]
@@ -211,9 +260,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
211260
@test !GradedAxes.isdual(axes(a[I, I], 1))
212261
end
213262

214-
# Test case when all axes are dual
215-
# from taking the adjoint.
216-
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
263+
# Test case when all axes are dual from taking the adjoint.
264+
for r in (
265+
gradedrange([U1(0) => 2, U1(1) => 2]),
266+
gradedrange([U1(0) => 2, U1(1) => 2])[begin:end],
267+
)
217268
a = BlockSparseArray{elt}(r, r)
218269
@views for i in [Block(1, 1), Block(2, 2)]
219270
a[i] = randn(elt, size(a[i]))
@@ -226,9 +277,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
226277
end
227278

228279
I = [Block(1)[1:1]]
229-
@test size(a[I, :]) == (1, 4)
230-
@test size(a[:, I]) == (4, 1)
231-
@test size(a[I, I]) == (1, 1)
280+
@test_broken size(b[I, :]) == (1, 4)
281+
@test_broken size(b[:, I]) == (4, 1)
282+
@test size(b[I, I]) == (1, 1)
232283
end
233284
end
234285
@testset "Matrix multiplication" begin

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
using BlockArrays:
2-
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock
2+
AbstractBlockedUnitRange,
3+
BlockArrays,
4+
Block,
5+
BlockIndexRange,
6+
BlockedVector,
7+
blocklength,
8+
blocksize,
9+
viewblock
310

411
# This splits `BlockIndexRange{N}` into
512
# `NTuple{N,BlockIndexRange{1}}`.
@@ -191,7 +198,9 @@ function to_blockindexrange(
191198
# work right now.
192199
return blocks(a.blocks)[Int(I)]
193200
end
194-
function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1})
201+
function to_blockindexrange(
202+
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
203+
)
195204
@assert I in only(blockaxes(a.indices))
196205
return I
197206
end

0 commit comments

Comments
 (0)