Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.1"
version = "0.5.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
11 changes: 8 additions & 3 deletions src/BlockArraysExtensions/BlockArraysExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ using SparseArraysBase:

# A return type for `blocks(array)` when `array` isn't blocked.
# Represents a vector with just that single block.
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
struct SingleBlockView{N,Array<:AbstractArray{<:Any,N}} <: AbstractArray{Array,N}
array::Array
end
Base.parent(a::SingleBlockView) = a.array
Base.size(a::SingleBlockView) = ntuple(Returns(1), ndims(a))
blocks_maybe_single(a) = blocks(a)
blocks_maybe_single(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
function Base.getindex(a::SingleBlockView{N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return parent(a)
end
Expand Down Expand Up @@ -357,7 +358,11 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
end

function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
return Block(1):Block(1)
return Block.(Base.OneTo(1))
end

function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer})
return Block.(Base.OneTo(1))
end

function blockrange(axis::AbstractUnitRange, r)
Expand Down
1 change: 1 addition & 0 deletions src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
include("abstractblocksparsearray/abstractblocksparsevector.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/unblockedsubarray.jl")
include("abstractblocksparsearray/views.jl")
include("abstractblocksparsearray/arraylayouts.jl")
include("abstractblocksparsearray/sparsearrayinterface.jl")
Expand Down
11 changes: 10 additions & 1 deletion src/abstractblocksparsearray/arraylayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,20 @@ function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, ax
return a_dest
end

function _similar(arraytype::Type{<:AbstractArray}, size::Tuple)
return similar(arraytype, size)
end
function _similar(
::Type{<:SubArray{<:Any,<:Any,<:ArrayType}}, size::Tuple
) where {ArrayType}
return similar(ArrayType, size)
end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
a_dest = blocktype(a)(undef, length.(axes))
a_dest = _similar(blocktype(a), length.(axes))
a_dest .= a
return a_dest
end
86 changes: 86 additions & 0 deletions src/abstractblocksparsearray/unblockedsubarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using ArrayLayouts: ArrayLayouts, MemoryLayout
using Base.Broadcast: Broadcast, BroadcastStyle
using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice
using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype

const UnblockedIndices = Union{
Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}}
}

const UnblockedSubArray{T,N} = SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{UnblockedIndices}}
}

function BlockArrays.blocks(a::UnblockedSubArray)
return SingleBlockView(a)
end

function DerivableInterfaces.interface(arraytype::Type{<:UnblockedSubArray})
return interface(blocktype(parenttype(arraytype)))
end

function ArrayLayouts.MemoryLayout(arraytype::Type{<:UnblockedSubArray})
return MemoryLayout(blocktype(parenttype(arraytype)))
end

function Broadcast.BroadcastStyle(arraytype::Type{<:UnblockedSubArray})
return BroadcastStyle(blocktype(parenttype(arraytype)))
end

function TypeParameterAccessors.similartype(arraytype::Type{<:UnblockedSubArray}, elt::Type)
return similartype(blocktype(parenttype(arraytype)), elt)
end

function Base.similar(
a::UnblockedSubArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
)
return similar(similartype(blocktype(parenttype(a)), elt), axes)
end
function Base.similar(a::UnblockedSubArray, elt::Type, size::Tuple{Int,Vararg{Int}})
return similar(a, elt, Base.OneTo.(size))
end

function ArrayLayouts.sub_materialize(a::UnblockedSubArray)
a_cpu = adapt(Array, a)
a_cpu′ = similar(a_cpu)
a_cpu′ .= a_cpu
if typeof(a) === typeof(a_cpu)
return a_cpu′
end
a′ = similar(a)
a′ .= a_cpu′
return a′
end

function Base.map!(
f, a_dest::AbstractArray, a_src1::UnblockedSubArray, a_src_rest::UnblockedSubArray...
)
return invoke(
map!,
Tuple{Any,AbstractArray,AbstractArray,Vararg{AbstractArray}},
f,
a_dest,
a_src1,
a_src_rest...,
)
end

# Fix ambiguity and scalar indexing errors with GPUArrays.
using Adapt: adapt
using GPUArraysCore: GPUArraysCore
function Base.map!(
f,
a_dest::GPUArraysCore.AnyGPUArray,
a_src1::UnblockedSubArray,
a_src_rest::UnblockedSubArray...,
)
a_dest_cpu = adapt(Array, a_dest)
a_srcs_cpu = map(adapt(Array), (a_src1, a_src_rest...))
map!(f, a_dest_cpu, a_srcs_cpu...)
a_dest .= a_dest_cpu
return a_dest
end

function Base.iszero(a::UnblockedSubArray)
return invoke(iszero, Tuple{AbstractArray}, adapt(Array, a))
end
26 changes: 23 additions & 3 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,19 @@ end
function Base.size(a::SparseSubArrayBlocks)
return length.(axes(a))
end
# TODO: Define `isstored`.

# TODO: Make a faster version for when the slice is blockwise.
function SparseArraysBase.isstored(
a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}
) where {N}
J = Base.reindex(parentindices(a.array), to_indices(a.array, Block.(I)))
# TODO: Try doing this blockwise when possible rather
# than elementwise.
return any(Iterators.product(J...)) do K
return isstored(parent(a.array), K...)
end
end

# TODO: Define `getstoredindex`, `getunstoredindex` instead.
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
# TODO: Should this be defined as `@view a.array[Block(I)]` instead?
Expand Down Expand Up @@ -400,9 +412,17 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
# TODO: Implement this properly.
return true
end
function SparseArraysBase.eachstoredindex(a::SparseSubArrayBlocks)
return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))

function SparseArraysBase.eachstoredindex(::IndexCartesian, a::SparseSubArrayBlocks)
return filter(eachindex(a)) do I
return isstored(a, I)
end

## # TODO: This only works for blockwise slices, i.e. slices using
## # `BlockSliceCollection`.
## return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))
end

# TODO: Either make this the generic interface or define
# `SparseArraysBase.sparse_storage`, which is used
# to defined this.
Expand Down
21 changes: 14 additions & 7 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,6 @@ arrayts = (Array, JLArray)
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@test_broken a[:, 4]

# TODO: Fix this and turn it into a proper test.
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@test_broken a[:, [2, 4]]
@test_broken a[[3, 5], [2, 4]]

# TODO: Fix this and turn it into a proper test.
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
Expand Down Expand Up @@ -713,6 +706,20 @@ arrayts = (Array, JLArray)
@test a[Block(2, 2)[1:2, 2:3]] == b
@test blockstoredlength(a) == 1

# Non-contiguous slicing.
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
I = ([3, 5], [2, 4])
@test Array(a[I...]) == Array(a)[I...]

# TODO: Fix this and turn it into a proper test.
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
I = (:, [2, 4])
@test Array(a[I...]) == Array(a)[I...]

a = BlockSparseArray{elt}(undef, [2, 3], [2, 3])
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
Expand Down
Loading