Skip to content

Commit 82fda96

Browse files
committed
Merge branch 'main' into vec_show
2 parents d607e9b + 10f6ebf commit 82fda96

File tree

6 files changed

+71
-18
lines changed

6 files changed

+71
-18
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.2"
4+
version = "0.3.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -34,7 +34,7 @@ Aqua = "0.8.9"
3434
ArrayLayouts = "1.10.4"
3535
BlockArrays = "1.2.0"
3636
DerivableInterfaces = "0.3.8"
37-
DiagonalArrays = "0.2.2"
37+
DiagonalArrays = "0.3"
3838
Dictionaries = "0.4.3"
3939
FillArrays = "1.13.0"
4040
GPUArraysCore = "0.1.0, 0.2"

src/abstractblocksparsearray/views.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,21 @@ function Base.view(
130130
) where {T,N}
131131
return viewblock(a, block)
132132
end
133+
134+
# Fix ambiguity error with BlockArrays.jl for slices like
135+
# `a = BlockSparseArray{Float64}(undef, [2, 2], [2, 2]); @view a[:, :]`.
136+
function Base.view(
137+
a::SubArray{
138+
T,
139+
N,
140+
<:AbstractBlockSparseArray{T,N},
141+
<:Tuple{Vararg{Union{Base.Slice,BlockSlice{<:BlockRange{1}}},N}},
142+
},
143+
block::Block{N},
144+
) where {T,N}
145+
return viewblock(a, block)
146+
end
147+
133148
function Base.view(
134149
a::SubArray{
135150
T,

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -363,14 +363,25 @@ function Base.replace_in_print_matrix(
363363
end
364364

365365
# attempt to catch things that wrap GPU arrays
366-
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
367-
X_cpu = adapt(Array, X)
368-
if typeof(X_cpu) === typeof(X) # prevent infinite recursion
366+
function Base.print_array(io::IO, a::AnyAbstractBlockSparseArray)
367+
a_cpu = adapt(Array, a)
368+
if typeof(a_cpu) === typeof(a) # prevent infinite recursion
369369
# need to specify ndims to allow specialized code for vector/matrix
370370
@allowscalar @invoke Base.print_array(
371-
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
371+
io, a_cpu::AbstractArray{eltype(a_cpu),ndims(a_cpu)}
372372
)
373-
else
374-
Base.print_array(io, X_cpu)
373+
return nothing
375374
end
375+
Base.print_array(io, a_cpu)
376+
return nothing
377+
end
378+
379+
using Adapt: Adapt, adapt
380+
function Adapt.adapt_structure(to, a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray})
381+
# In the generic definition in Adapt.jl, `parentindices(a)` are also
382+
# adapted, but is broken when the parent indices contained blocked unit
383+
# ranges since `adapt` is broken on blocked unit ranges.
384+
# TODO: Fix adapt for blocked unit ranges by making an AdaptExt for
385+
# BlockArrays.jl.
386+
return SubArray(adapt(to, parent(a)), parentindices(a))
376387
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using BlockArrays:
1313
blockcheckbounds,
1414
blockisequal,
1515
blocklengths,
16+
blocklength,
1617
blocks,
1718
findblockindex
1819
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface
@@ -419,6 +420,7 @@ end
419420

420421
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
421422
to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks)
423+
to_blocks_indices(I::Base.Slice{<:BlockedOneTo}) = Base.OneTo(blocklength(I.indices))
422424

423425
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(
424426
a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}}

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Aqua = "0.8"
2727
ArrayLayouts = "1"
2828
BlockArrays = "1"
2929
BlockSparseArrays = "0.3"
30-
DiagonalArrays = "0.2"
30+
DiagonalArrays = "0.3"
3131
GPUArraysCore = "0.2"
3232
GradedUnitRanges = "0.1"
3333
JLArrays = "0.2"

test/test_basics.jl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ using BlockSparseArrays:
3434
sparsemortar,
3535
view!
3636
using GPUArraysCore: @allowscalar
37-
using JLArrays: JLArray
37+
using JLArrays: JLArray, JLMatrix
3838
using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
3939
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
4040
using TensorAlgebra: contract
@@ -315,6 +315,27 @@ arrayts = (Array, JLArray)
315315
@test @views(at[Block(1, 2)]) isa Adjoint
316316
end
317317
end
318+
@testset "adapt" begin
319+
a = BlockSparseArray{elt}(undef, [2, 2], [2, 2])
320+
a_12 = randn(elt, 2, 2)
321+
a[Block(1, 2)] = a_12
322+
a_jl = adapt(JLArray, a)
323+
@test a_jl isa BlockSparseMatrix{elt,JLMatrix{elt}}
324+
@test blocktype(a_jl) == JLMatrix{elt}
325+
@test blockstoredlength(a_jl) == 1
326+
@test a_jl[Block(1, 2)] isa JLMatrix{elt}
327+
@test adapt(Array, a_jl[Block(1, 2)]) == a_12
328+
329+
a = BlockSparseArray{elt}(undef, [2, 2], [2, 2])
330+
a_12 = randn(elt, 2, 2)
331+
a[Block(1, 2)] = a_12
332+
a_jl = adapt(JLArray, @view(a[:, :]))
333+
@test a_jl isa SubArray{elt,2,<:BlockSparseMatrix{elt,JLMatrix{elt}}}
334+
@test blocktype(a_jl) == JLMatrix{elt}
335+
@test blockstoredlength(a_jl) == 1
336+
@test a_jl[Block(1, 2)] isa JLMatrix{elt}
337+
@test adapt(Array, a_jl[Block(1, 2)]) == a_12
338+
end
318339
@testset "Tensor algebra" begin
319340
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
320341
@views for b in [Block(1, 2), Block(2, 1)]
@@ -1158,15 +1179,19 @@ arrayts = (Array, JLArray)
11581179
# Not testing other element types since they change the
11591180
# spacing so it isn't easy to make the test general.
11601181

1161-
a = BlockSparseMatrix{elt,arrayt{elt,2}}(undef, [2, 2], [2, 2])
1162-
@allowscalar a[1, 2] = 12
1163-
@test sprint(show, "text/plain", a) ==
1164-
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ "
1182+
a′ = BlockSparseMatrix{elt,arrayt{elt,2}}(undef, [2, 2], [2, 2])
1183+
@allowscalar a′[1, 2] = 12
1184+
for a in (a′, @view(a′[:, :]))
1185+
@test sprint(show, "text/plain", a) ==
1186+
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ "
1187+
end
11651188

1166-
a = BlockSparseArray{elt,3,arrayt{elt,3}}(undef, [2, 2], [2, 2], [2, 2])
1167-
@allowscalar a[1, 2, 1] = 121
1168-
@test sprint(show, "text/plain", a) ==
1169-
"$(summary(a)):\n[:, :, 1] =\n $(zero(eltype(a))) $(eltype(a)(121)) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 2] =\n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 3] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 4] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ "
1189+
a′ = BlockSparseArray{elt,3,arrayt{elt,3}}(undef, [2, 2], [2, 2], [2, 2])
1190+
@allowscalar a′[1, 2, 1] = 121
1191+
for a in (a′, @view(a′[:, :, :]))
1192+
@test sprint(show, "text/plain", a) ==
1193+
"$(summary(a)):\n[:, :, 1] =\n $(zero(eltype(a))) $(eltype(a)(121)) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 2] =\n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 3] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 4] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ "
1194+
end
11701195
end
11711196
end
11721197
@testset "TypeParameterAccessors.position" begin

0 commit comments

Comments
 (0)