diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 3f54093e..fd90bcf7 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -2,7 +2,7 @@ module BlockSparseArrays include("BlockArraysExtensions/BlockArraysExtensions.jl") include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") -include("blocksparsearrayinterface/blockzero.jl") +include("blocksparsearrayinterface/getunstoredblock.jl") include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/map.jl") include("blocksparsearrayinterface/arraylayouts.jl") diff --git a/src/blocksparsearray/defaults.jl b/src/blocksparsearray/defaults.jl index 938610d3..788a9750 100644 --- a/src/blocksparsearray/defaults.jl +++ b/src/blocksparsearray/defaults.jl @@ -38,5 +38,5 @@ function default_blocks( block_data::Dictionary{<:CartesianIndex{N},<:AbstractArray{<:Any,N}}, axes::Tuple{Vararg{AbstractUnitRange,N}}, ) where {N} - return SparseArrayDOK(block_data, blocklength.(axes), BlockZero(axes)) + return SparseArrayDOK(block_data, blocklength.(axes), GetUnstoredBlock(axes)) end diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index b4583c6f..9872ca1f 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -65,7 +65,7 @@ end ) # TODO: Use `Block()[]` once https://github.com/JuliaArrays/BlockArrays.jl/issues/430 # is fixed. - return a[BlockIndex{0,Tuple{},Tuple{}}((), ())] + return a[BlockIndex()] end # a[1:2, 1:2] @@ -135,7 +135,7 @@ end ) # TODO: Use `Block()[]` once https://github.com/JuliaArrays/BlockArrays.jl/issues/430 # is fixed. - a[BlockIndex{0,Tuple{},Tuple{}}((), ())] = value + a[BlockIndex()] = value return a end @@ -301,6 +301,8 @@ end function Base.size(a::SparseSubArrayBlocks) return length.(axes(a)) end +# TODO: Define `isstored`. +# TODO: Define `getstoredindex`, `getunstoredindex` instead. function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} # TODO: Should this be defined as `@view a.array[Block(I)]` instead? return @view a.array[Block(I)] @@ -312,9 +314,11 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where ## return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...] end # TODO: This should be handled by generic `AbstractSparseArray` code. +# TODO: Define `getstoredindex`, `getunstoredindex` instead. function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} return a[Tuple(I)...] end +# TODO: Define `setstoredindex!`, `setunstoredindex!` instead. function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N} parent_blocks = @view blocks(parent(a.array))[blockrange(a)...] # TODO: The following line is required to instantiate @@ -345,18 +349,11 @@ SparseArraysBase.storedlength(a::SparseSubArrayBlocks) = length(eachstoredindex( ## array::Array ## end -## TODO: Delete. +## TODO: Define `storedvalues` instead. ## function SparseArraysBase.sparse_storage(a::SparseSubArrayBlocks) ## return map(I -> a[I], eachstoredindex(a)) ## end -## TODO: Delete. -## function SparseArraysBase.getindex_zero_function(a::SparseSubArrayBlocks) -## # TODO: Base it off of `getindex_zero_function(blocks(parent(a.array))`, but replace the -## # axes with `axes(a.array)`. -## return BlockZero(axes(a.array)) -## end - function SparseArraysBase.getunstoredindex( a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N} ) where {N} diff --git a/src/blocksparsearrayinterface/blockzero.jl b/src/blocksparsearrayinterface/blockzero.jl deleted file mode 100644 index 286823f7..00000000 --- a/src/blocksparsearrayinterface/blockzero.jl +++ /dev/null @@ -1,49 +0,0 @@ -using BlockArrays: Block, blockedrange - -# Extensions to BlockArrays.jl -blocktuple(b::Block) = Block.(b.n) -inttuple(b::Block) = b.n - -# The size of a block -function block_size(axes::Tuple{Vararg{AbstractUnitRange}}, block::Block) - return length.(getindex.(axes, blocktuple(block))) -end - -# The size of a block -function block_size(blockinds::Tuple{Vararg{AbstractVector}}, block::Block) - return block_size(blockedrange.(blockinds), block) -end - -struct BlockZero{Axes} - axes::Axes -end - -function (f::BlockZero)(a::AbstractArray, I...) - return f(eltype(a), I...) -end - -function (f::BlockZero)(arraytype::Type{<:SubArray{<:Any,<:Any,P}}, I...) where {P} - return f(P, I...) -end - -function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I::CartesianIndex) - return f(arraytype, Tuple(I)...) -end - -function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I::Int...) - # TODO: Make sure this works for sparse or block sparse blocks, immutable - # blocks, diagonal blocks, etc.! - blck_size = block_size(f.axes, Block(I)) - blck_type = similartype(arraytype, blck_size) - return fill!(blck_type(undef, blck_size), false) -end - -# Fallback so that `SparseArray` with scalar elements works. -function (f::BlockZero)(blocktype::Type{<:Number}, I...) - return zero(blocktype) -end - -# Fallback to Array if it is abstract -function (f::BlockZero)(arraytype::Type{AbstractArray{T,N}}, I...) where {T,N} - return f(Array{T,N}, I...) -end diff --git a/src/blocksparsearrayinterface/getunstoredblock.jl b/src/blocksparsearrayinterface/getunstoredblock.jl new file mode 100644 index 00000000..db9267e5 --- /dev/null +++ b/src/blocksparsearrayinterface/getunstoredblock.jl @@ -0,0 +1,25 @@ +using ArrayLayouts: zero! +using BlockArrays: Block + +struct GetUnstoredBlock{Axes} + axes::Axes +end + +@inline function (f::GetUnstoredBlock)( + a::AbstractArray{<:Any,N}, I::Vararg{Int,N} +) where {N} + # TODO: Make sure this works for sparse or block sparse blocks, immutable + # blocks, diagonal blocks, etc.! + b_size = ntuple(ndims(a)) do d + return length(f.axes[d][Block(I[d])]) + end + b = similar(eltype(a), b_size) + zero!(b) + return b +end +# TODO: Use `Base.to_indices`. +@inline function (f::GetUnstoredBlock)( + a::AbstractArray{<:Any,N}, I::CartesianIndex{N} +) where {N} + return f(a, Tuple(I)...) +end