Skip to content

Commit d9645fe

Browse files
committed
Fix broken tests
1 parent 98a000f commit d9645fe

File tree

3 files changed

+82
-22
lines changed

3 files changed

+82
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ KroneckerArraysTensorProductsExt = "TensorProducts"
2727
[compat]
2828
Adapt = "4.3"
2929
BlockArrays = "1.6"
30-
BlockSparseArrays = "0.9, 0.10"
30+
BlockSparseArrays = "0.9, 0.10.3"
3131
DerivableInterfaces = "0.5.3"
3232
DiagonalArrays = "0.3.11"
3333
FillArrays = "1.13"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
2323
return blockrange(map(cartesianrange, bs))
2424
end
2525

26+
using BlockArrays: BlockArrays, mortar
27+
using BlockSparseArrays: blockrange
28+
using KroneckerArrays: CartesianProductUnitRange
29+
# Makes sure that `mortar` results in a `BlockVector` with the correct
30+
# axes, otherwise the axes would not preserve the Kronecker structure.
31+
# This is helpful when indexing `BlockUnitRange`, for example:
32+
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.7.1/src/blockaxis.jl#L540-L547
33+
function BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange})
34+
return mortar(blocks, (blockrange(map(Base.axes1, blocks)),))
35+
end
36+
2637
using BlockArrays: AbstractBlockedUnitRange
2738
using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
2839
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2, _similar

test/test_blocksparsearrays.jl

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ arrayts = (Array, JLArray)
3232
@test blockisequal(arg1(r), blockedrange([2, 3]))
3333
@test blockisequal(arg2(r), blockedrange([3, 4]))
3434

35+
r = blockrange([2 × 3, 3 × 4])
36+
r′ = r[Block.([2, 1])]
37+
@test r′[Block(1)] cartesianrange(3 × 4, 7:18)
38+
@test r′[Block(2)] cartesianrange(2 × 3, 1:6)
39+
@test eachblockaxis(r′)[1] cartesianrange(3, 4)
40+
@test eachblockaxis(r′)[2] cartesianrange(2, 3)
41+
3542
dev = adapt(arrayt)
3643
r = blockrange([2 × 2, 3 × 3])
3744
d = Dict(
@@ -137,13 +144,8 @@ arrayts = (Array, JLArray)
137144
@test_broken inv(a)
138145
end
139146

140-
if arrayt === Array
141-
u, s, v = svd_compact(a)
142-
@test Array(u * s * v) Array(a)
143-
else
144-
# Broken on GPU.
145-
@test_broken svd_compact(a)
146-
end
147+
u, s, v = svd_compact(a)
148+
@test Array(u * s * v) Array(a)
147149

148150
b = a[Block.(1:2), Block(2)]
149151
@test b[Block(1)] == a[Block(1, 2)]
@@ -236,59 +238,106 @@ end
236238
@test_broken copy(b)
237239
@test_broken b[Block(1, 2)]
238240

241+
r = blockrange([2 × 2, 3 × 3])
242+
d = Dict(
243+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
244+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
245+
)
246+
a = dev(blocksparse(d, (r, r)))
239247
b = @constinferred a * a
240248
@test typeof(b) === typeof(a)
241249
@test Array(b) Array(a) * Array(a)
242250

251+
r = blockrange([2 × 2, 3 × 3])
252+
d = Dict(
253+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
254+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
255+
)
256+
a = dev(blocksparse(d, (r, r)))
243257
# Type inference is broken for this operation.
244258
# b = @constinferred a + a
245259
b = a + a
246260
@test typeof(b) === typeof(a)
247261
@test Array(b) Array(a) + Array(a)
248262

263+
r = blockrange([2 × 2, 3 × 3])
264+
d = Dict(
265+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
266+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
267+
)
268+
a = dev(blocksparse(d, (r, r)))
249269
# Type inference is broken for this operation.
250270
# b = @constinferred 3a
251271
b = 3a
252272
@test typeof(b) === typeof(a)
253273
@test Array(b) 3Array(a)
254274

275+
r = blockrange([2 × 2, 3 × 3])
276+
d = Dict(
277+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
278+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
279+
)
280+
a = dev(blocksparse(d, (r, r)))
255281
# Type inference is broken for this operation.
256282
# b = @constinferred a / 3
257283
b = a / 3
258284
@test typeof(b) === typeof(a)
259285
@test Array(b) Array(a) / 3
260286

287+
r = blockrange([2 × 2, 3 × 3])
288+
d = Dict(
289+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
290+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
291+
)
292+
a = dev(blocksparse(d, (r, r)))
261293
@test @constinferred(norm(a)) norm(Array(a))
262294

295+
r = blockrange([2 × 2, 3 × 3])
296+
d = Dict(
297+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
298+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
299+
)
300+
a = dev(blocksparse(d, (r, r)))
263301
if arrayt === Array
264302
b = @constinferred exp(a)
265303
@test Array(b) exp(Array(a))
266304
else
267305
@test_broken exp(a)
268306
end
269307

270-
## if VERSION < v"1.11-" && elt <: Complex
271-
## # Broken because of type stability issue in Julia v1.10.
272-
## @test_broken svd_compact(a)
273-
if arrayt === Array
274-
u, s, v = svd_compact(a)
275-
@test u * s * v a
276-
@test blocktype(u) >: blocktype(u)
277-
@test eltype(u) === eltype(a)
278-
@test blocktype(v) >: blocktype(a)
279-
@test eltype(v) === eltype(a)
280-
@test eltype(s) === real(eltype(a))
281-
else
282-
@test_broken svd_compact(a)
283-
end
308+
r = blockrange([2 × 2, 3 × 3])
309+
d = Dict(
310+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
311+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
312+
)
313+
a = dev(blocksparse(d, (r, r)))
314+
u, s, v = svd_compact(a)
315+
@test u * s * v a
316+
@test blocktype(u) >: blocktype(u)
317+
@test eltype(u) === eltype(a)
318+
@test blocktype(v) >: blocktype(a)
319+
@test eltype(v) === eltype(a)
320+
@test eltype(s) === real(eltype(a))
284321

322+
r = blockrange([2 × 2, 3 × 3])
323+
d = Dict(
324+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
325+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
326+
)
327+
a = dev(blocksparse(d, (r, r)))
285328
if arrayt === Array
286329
@test Array(inv(a)) inv(Array(a))
287330
else
288331
# Broken on GPU.
289332
@test_broken inv(a)
290333
end
291334

335+
r = blockrange([2 × 2, 3 × 3])
336+
d = Dict(
337+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
338+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
339+
)
340+
a = dev(blocksparse(d, (r, r)))
292341
# Broken operations
293342
b = a[Block.(1:2), Block(2)]
294343
@test b[Block(1)] == a[Block(1, 2)]

0 commit comments

Comments
 (0)