Skip to content

Commit bf2a7a8

Browse files
authored
Merge branch 'main' into graded_svd
2 parents 76ce221 + 6c05eb9 commit bf2a7a8

File tree

8 files changed

+202
-24
lines changed

8 files changed

+202
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.2"
4+
version = "0.5.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ using SparseArraysBase:
3030

3131
# A return type for `blocks(array)` when `array` isn't blocked.
3232
# Represents a vector with just that single block.
33-
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
33+
struct SingleBlockView{N,Array<:AbstractArray{<:Any,N}} <: AbstractArray{Array,N}
3434
array::Array
3535
end
3636
Base.parent(a::SingleBlockView) = a.array
37+
Base.size(a::SingleBlockView) = ntuple(Returns(1), ndims(a))
3738
blocks_maybe_single(a) = blocks(a)
3839
blocks_maybe_single(a::Array) = SingleBlockView(a)
39-
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
40+
function Base.getindex(a::SingleBlockView{N}, index::Vararg{Int,N}) where {N}
4041
@assert all(isone, index)
4142
return parent(a)
4243
end
@@ -289,13 +290,6 @@ function blockrange(axis::AbstractUnitRange, r::Int)
289290
return error("Slicing with integer values isn't supported.")
290291
end
291292

292-
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
293-
for b in r
294-
@assert b blockaxes(axis, 1)
295-
end
296-
return r
297-
end
298-
299293
# This handles changing the blocking, for example:
300294
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
301295
# I = blockedrange([4, 4])
@@ -314,13 +308,20 @@ end
314308
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
315309
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
316310
# a[I, I]
317-
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
311+
function blockrange(axis::AbstractUnitRange, r::AbstractBlockVector{<:Block{1}})
318312
for b in r
319313
@assert b blockaxes(axis, 1)
320314
end
321315
return only(blockaxes(r))
322316
end
323317

318+
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
319+
for b in r
320+
@assert b blockaxes(axis, 1)
321+
end
322+
return r
323+
end
324+
324325
using BlockArrays: BlockSlice
325326
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
326327
return blockrange(axis, r.block)
@@ -357,7 +358,11 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
357358
end
358359

359360
function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
360-
return Block(1):Block(1)
361+
return Block.(Base.OneTo(1))
362+
end
363+
364+
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer})
365+
return Block.(Base.OneTo(1))
361366
end
362367

363368
function blockrange(axis::AbstractUnitRange, r)

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl")
2727
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
2828
include("abstractblocksparsearray/abstractblocksparsevector.jl")
2929
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
30+
include("abstractblocksparsearray/unblockedsubarray.jl")
3031
include("abstractblocksparsearray/views.jl")
3132
include("abstractblocksparsearray/arraylayouts.jl")
3233
include("abstractblocksparsearray/sparsearrayinterface.jl")

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,20 @@ function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, ax
4343
return a_dest
4444
end
4545

46+
function _similar(arraytype::Type{<:AbstractArray}, size::Tuple)
47+
return similar(arraytype, size)
48+
end
49+
function _similar(
50+
::Type{<:SubArray{<:Any,<:Any,<:ArrayType}}, size::Tuple
51+
) where {ArrayType}
52+
return similar(ArrayType, size)
53+
end
54+
4655
# Materialize a SubArray view.
4756
function ArrayLayouts.sub_materialize(
4857
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
4958
)
50-
a_dest = blocktype(a)(undef, length.(axes))
59+
a_dest = _similar(blocktype(a), length.(axes))
5160
a_dest .= a
5261
return a_dest
5362
end
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using ArrayLayouts: ArrayLayouts, MemoryLayout
2+
using Base.Broadcast: Broadcast, BroadcastStyle
3+
using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice
4+
using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype
5+
6+
const UnblockedIndices = Union{
7+
Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}}
8+
}
9+
10+
const UnblockedSubArray{T,N} = SubArray{
11+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{UnblockedIndices}}
12+
}
13+
14+
function BlockArrays.blocks(a::UnblockedSubArray)
15+
return SingleBlockView(a)
16+
end
17+
18+
function DerivableInterfaces.interface(arraytype::Type{<:UnblockedSubArray})
19+
return interface(blocktype(parenttype(arraytype)))
20+
end
21+
22+
function ArrayLayouts.MemoryLayout(arraytype::Type{<:UnblockedSubArray})
23+
return MemoryLayout(blocktype(parenttype(arraytype)))
24+
end
25+
26+
function Broadcast.BroadcastStyle(arraytype::Type{<:UnblockedSubArray})
27+
return BroadcastStyle(blocktype(parenttype(arraytype)))
28+
end
29+
30+
function TypeParameterAccessors.similartype(arraytype::Type{<:UnblockedSubArray}, elt::Type)
31+
return similartype(blocktype(parenttype(arraytype)), elt)
32+
end
33+
34+
function Base.similar(
35+
a::UnblockedSubArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
36+
)
37+
return similar(similartype(blocktype(parenttype(a)), elt), axes)
38+
end
39+
function Base.similar(a::UnblockedSubArray, elt::Type, size::Tuple{Int,Vararg{Int}})
40+
return similar(a, elt, Base.OneTo.(size))
41+
end
42+
43+
function ArrayLayouts.sub_materialize(a::UnblockedSubArray)
44+
a_cpu = adapt(Array, a)
45+
a_cpu′ = similar(a_cpu)
46+
a_cpu′ .= a_cpu
47+
if typeof(a) === typeof(a_cpu)
48+
return a_cpu′
49+
end
50+
a′ = similar(a)
51+
a′ .= a_cpu′
52+
return a′
53+
end
54+
55+
function Base.map!(
56+
f, a_dest::AbstractArray, a_src1::UnblockedSubArray, a_src_rest::UnblockedSubArray...
57+
)
58+
return invoke(
59+
map!,
60+
Tuple{Any,AbstractArray,AbstractArray,Vararg{AbstractArray}},
61+
f,
62+
a_dest,
63+
a_src1,
64+
a_src_rest...,
65+
)
66+
end
67+
68+
# Fix ambiguity and scalar indexing errors with GPUArrays.
69+
using Adapt: adapt
70+
using GPUArraysCore: GPUArraysCore
71+
function Base.map!(
72+
f,
73+
a_dest::GPUArraysCore.AnyGPUArray,
74+
a_src1::UnblockedSubArray,
75+
a_src_rest::UnblockedSubArray...,
76+
)
77+
a_dest_cpu = adapt(Array, a_dest)
78+
a_srcs_cpu = map(adapt(Array), (a_src1, a_src_rest...))
79+
map!(f, a_dest_cpu, a_srcs_cpu...)
80+
a_dest .= a_dest_cpu
81+
return a_dest
82+
end
83+
84+
function Base.iszero(a::UnblockedSubArray)
85+
return invoke(iszero, Tuple{AbstractArray}, adapt(Array, a))
86+
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,19 @@ end
364364
function Base.size(a::SparseSubArrayBlocks)
365365
return length.(axes(a))
366366
end
367-
# TODO: Define `isstored`.
367+
368+
# TODO: Make a faster version for when the slice is blockwise.
369+
function SparseArraysBase.isstored(
370+
a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}
371+
) where {N}
372+
J = Base.reindex(parentindices(a.array), to_indices(a.array, Block.(I)))
373+
# TODO: Try doing this blockwise when possible rather
374+
# than elementwise.
375+
return any(Iterators.product(J...)) do K
376+
return isstored(parent(a.array), K...)
377+
end
378+
end
379+
368380
# TODO: Define `getstoredindex`, `getunstoredindex` instead.
369381
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
370382
# TODO: Should this be defined as `@view a.array[Block(I)]` instead?
@@ -400,9 +412,17 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
400412
# TODO: Implement this properly.
401413
return true
402414
end
403-
function SparseArraysBase.eachstoredindex(a::SparseSubArrayBlocks)
404-
return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))
415+
416+
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::SparseSubArrayBlocks)
417+
return filter(eachindex(a)) do I
418+
return isstored(a, I)
419+
end
420+
421+
## # TODO: This only works for blockwise slices, i.e. slices using
422+
## # `BlockSliceCollection`.
423+
## return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))
405424
end
425+
406426
# TODO: Either make this the generic interface or define
407427
# `SparseArraysBase.sparse_storage`, which is used
408428
# to defined this.
@@ -425,7 +445,7 @@ end
425445

426446
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
427447
to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks)
428-
to_blocks_indices(I::Base.Slice{<:BlockedOneTo}) = Base.OneTo(blocklength(I.indices))
448+
to_blocks_indices(I::Base.Slice) = Base.OneTo(blocklength(I.indices))
429449

430450
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(
431451
a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}}

src/blocksparsearrayinterface/map.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,52 @@
1+
using BlockArrays: BlockRange, blockisequal
12
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
23
using GPUArraysCore: @allowscalar
34

5+
# Check if the block structures are the same.
6+
function same_block_structure(as::AbstractArray...)
7+
isempty(as) && return true
8+
return all(
9+
ntuple(ndims(first(as))) do dim
10+
ax = map(Base.Fix2(axes, dim), as)
11+
return blockisequal(ax...)
12+
end,
13+
)
14+
end
15+
16+
# Find the common stored blocks, assuming the block structures are the same.
17+
function union_eachblockstoredindex(as::AbstractArray...)
18+
return (map(eachblockstoredindex, as)...)
19+
end
20+
21+
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
22+
# TODO: This assumes element types are numbers, generalize this logic.
23+
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
24+
Is = if f_preserves_zeros
25+
union_eachblockstoredindex(a_dest, a_srcs...)
26+
else
27+
BlockRange(a_dest)
28+
end
29+
for I in Is
30+
# TODO: Use:
31+
# block_dest = @view a_dest[I]
32+
# or:
33+
# block_dest = @view! a_dest[I]
34+
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
35+
# TODO: Use:
36+
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
37+
block_srcs = map(a_srcs) do a_src
38+
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
39+
end
40+
# TODO: Use `map!!` to handle immutable blocks.
41+
map!(f, block_dest, block_srcs...)
42+
# Replace the entire block, handles initializing new blocks
43+
# or if blocks are immutable.
44+
# TODO: Use `a_dest[I] = block_dest`.
45+
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
46+
end
47+
return a_dest
48+
end
49+
450
# TODO: Rewrite this so that it takes the blocking structure
551
# made by combining the blocking of the axes (i.e. the blocking that
652
# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +62,10 @@ using GPUArraysCore: @allowscalar
1662
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
1763
return a_dest
1864
end
65+
if same_block_structure(a_dest, a_srcs...)
66+
map_blockwise!(f, a_dest, a_srcs...)
67+
return a_dest
68+
end
1969
# TODO: This assumes element types are numbers, generalize this logic.
2070
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
2171
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)

test/test_basics.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,6 @@ arrayts = (Array, JLArray)
5151
a[Block(2, 2)] = dev(randn(elt, 3, 3))
5252
@test_broken a[:, 4]
5353

54-
# TODO: Fix this and turn it into a proper test.
55-
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
56-
a[Block(1, 1)] = dev(randn(elt, 2, 2))
57-
a[Block(2, 2)] = dev(randn(elt, 3, 3))
58-
@test_broken a[:, [2, 4]]
59-
@test_broken a[[3, 5], [2, 4]]
60-
6154
# TODO: Fix this and turn it into a proper test.
6255
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
6356
a[Block(1, 1)] = dev(randn(elt, 2, 2))
@@ -713,6 +706,20 @@ arrayts = (Array, JLArray)
713706
@test a[Block(2, 2)[1:2, 2:3]] == b
714707
@test blockstoredlength(a) == 1
715708

709+
# Noncontiguous slicing.
710+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
711+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
712+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
713+
I = ([3, 5], [2, 4])
714+
@test Array(a[I...]) == Array(a)[I...]
715+
716+
# Noncontiguous slicing.
717+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
718+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
719+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
720+
I = (:, [2, 4])
721+
@test Array(a[I...]) == Array(a)[I...]
722+
716723
a = BlockSparseArray{elt}(undef, [2, 3], [2, 3])
717724
@views for b in [Block(1, 1), Block(2, 2)]
718725
a[b] = randn(elt, size(a[b]))

0 commit comments

Comments
 (0)