Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"

[extensions]
BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges"
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]

[compat]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module BlockSparseArraysGradedUnitRangesExt

using BlockSparseArrays: AnyAbstractBlockSparseArray, BlockSparseArray, blocktype
using GradedUnitRanges: AbstractGradedUnitRange
using TypeParameterAccessors: set_ndims, unwrap_array_type

# A block spare array similar to the input (dense) array.
# TODO: Make `BlockSparseArrays.blocksparse_similar` more general and use that,
# and also turn it into an DerivableInterfaces.jl-based interface function.
function similar_blocksparse(
a::AbstractArray,
elt::Type,
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
)
# TODO: Probably need to unwrap the type of `a` in certain cases
# to make a proper block type.
return BlockSparseArray{
elt,length(axes),set_ndims(unwrap_array_type(blocktype(a)), length(axes))
}(
axes
)
end

function Base.similar(

Check warning on line 24 in ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl#L24

Added line #L24 was not covered by tests
a::AbstractArray,
elt::Type,
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
)
return similar_blocksparse(a, elt, axes)

Check warning on line 29 in ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl#L29

Added line #L29 was not covered by tests
end

# Fix ambiguity error with `BlockArrays.jl`.
function Base.similar(
a::StridedArray,
elt::Type,
axes::Tuple{
AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}
},
)
return similar_blocksparse(a, elt, axes)
end

# Fix ambiguity error with `BlockSparseArrays.jl`.
function Base.similar(
a::AnyAbstractBlockSparseArray,
elt::Type,
axes::Tuple{
AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}
},
)
return similar_blocksparse(a, elt, axes)
end

function getindex_blocksparse(a::AbstractArray, I::AbstractUnitRange...)
a′ = similar(a, only.(axes.(I))...)
a′ .= a
return a′
end

function Base.getindex(
a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange...
)
return getindex_blocksparse(a, I1, I_rest...)
end

# Fix ambiguity errors.
function Base.getindex(

Check warning on line 67 in ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl#L67

Added line #L67 was not covered by tests
a::AnyAbstractBlockSparseArray,
I1::AbstractGradedUnitRange,
I_rest::AbstractGradedUnitRange...,
)
return getindex_blocksparse(a, I1, I_rest...)

Check warning on line 72 in ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl#L72

Added line #L72 was not covered by tests
end
function Base.getindex(
a::AnyAbstractBlockSparseArray{<:Any,2},
I1::AbstractGradedUnitRange,
I2::AbstractGradedUnitRange,
)
return getindex_blocksparse(a, I1, I2)
end

end
20 changes: 20 additions & 0 deletions src/abstractblocksparsearray/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,23 @@ function Broadcast.BroadcastStyle(
)
return BlockSparseArrayStyle{ndims(arraytype)}()
end

# These catch cases that aren't caught by the standard
# `BlockSparseArrayStyle` definition, and also fix
# ambiguity issues.
function Base.copyto!(dest::AnyAbstractBlockSparseArray, bc::Broadcasted)
copyto_blocksparse!(dest, bc)
return dest
end
function Base.copyto!(
dest::AnyAbstractBlockSparseArray, bc::Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}
)
copyto_blocksparse!(dest, bc)
return dest
end
function Base.copyto!(
dest::AnyAbstractBlockSparseArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
) where {N}
copyto_blocksparse!(dest, bc)
return dest
end
7 changes: 5 additions & 2 deletions src/abstractblocksparsearray/cat.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using DerivableInterfaces: @interface, interface

# TODO: Define with `@derive`.
function Base.cat(as::AnyAbstractBlockSparseArray...; dims)
function Base._cat(dims, as::AnyAbstractBlockSparseArray...)
# TODO: Call `DerivableInterfaces.cat_along(dims, as...)` instead,
# for better inferability. See:
# https://github.com/ITensor/DerivableInterfaces.jl/pull/13
# https://github.com/ITensor/DerivableInterfaces.jl/pull/17
return @interface interface(as...) cat(as...; dims)
end
96 changes: 11 additions & 85 deletions src/abstractblocksparsearray/map.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,8 @@
using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
using GPUArraysCore: @allowscalar
using BlockArrays: AbstractBlockVector, Block
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase: SparseArraysBase, SparseArrayStyle

# Returns `Vector{<:CartesianIndices}`
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
combined_axes = combine_axes(axes.(as)...)
stored_blocked_cartesianindices_as = map(as) do a
return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a))
end
return ∪(stored_blocked_cartesianindices_as...)
end

# This is used by `map` to get the output axes.
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

reblock(a) = a

# TODO: Make this more general, independent of `AbstractBlockSparseArray`.
# If the blocking of the slice doesn't match the blocking of the
# parent array, reblock according to the blocking of the parent array.
function reblock(
Expand All @@ -32,12 +15,14 @@
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
end

# TODO: Make this more general, independent of `AbstractBlockSparseArray`.
function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
)
return @view parent(a)[map(I -> I.array, parentindices(a))...]
end

# TODO: Make this more general, independent of `AbstractBlockSparseArray`.
function reblock(
a::SubArray{
<:Any,
Expand All @@ -50,77 +35,18 @@
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
end

# `map!` specialized to zero-dimensional inputs.
function map_zero_dim! end

@interface ::AbstractArrayInterface function map_zero_dim!(
f, a_dest::AbstractArray, a_srcs::AbstractArray...
)
@allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)

Check warning on line 39 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
return a_dest
end

# TODO: Move to `blocksparsearrayinterface/map.jl`.
# TODO: Rewrite this so that it takes the blocking structure
# made by combining the blocking of the axes (i.e. the blocking that
# is used to determine `union_stored_blocked_cartesianindices(...)`).
# `reblock` is a partial solution to that, but a bit ad-hoc.
## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function.
@interface interface::AbstractBlockSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, a_srcs::AbstractArray...
)
if iszero(ndims(a_dest))
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
return a_dest
end

a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
# TODO: Use `map!!` to handle immutable blocks.
map!(f, subblock_dest, subblock_srcs...)
# Replace the entire block, handles initializing new blocks
# or if blocks are immutable.
blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest
end
function Base.map!(f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AbstractArray...)
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)

Check warning on line 43 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L42-L43

Added lines #L42 - L43 were not covered by tests
return a_dest
end

# TODO: Move to `blocksparsearrayinterface/map.jl`.
@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce(
f, op, as::AbstractArray...; kwargs...
function Base.map!(

Check warning on line 46 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L46

Added line #L46 was not covered by tests
f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AnyAbstractBlockSparseArray...
)
# TODO: Define an `init` value based on the element type.
return @interface interface(blocks.(as)...) mapreduce(
block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs...
)
end

# TODO: Move to `blocksparsearrayinterface/map.jl`.
@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray)
# TODO: Just call `iszero(blocks(a))`?
return @interface interface(blocks(a)) iszero(blocks(a))
end

# TODO: Move to `blocksparsearrayinterface/map.jl`.
@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray)
# TODO: Just call `isreal(blocks(a))`?
return @interface interface(blocks(a)) isreal(blocks(a))
end

function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
@interface interface(a_srcs...) map!(f, a_dest, a_srcs...)
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)

Check warning on line 49 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L49

Added line #L49 was not covered by tests
return a_dest
end

Expand Down
26 changes: 26 additions & 0 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BlockedVector,
block,
blockcheckbounds,
blockisequal,
blocklengths,
blocks,
findblockindex
Expand Down Expand Up @@ -40,6 +41,31 @@
return storedvalues(blocks(a))
end

# TODO: Generalize this, this catches simple cases
# where the more general definition isn't specific enough.
blocktype(a::Array) = typeof(a)
# TODO: Maybe unwrap SubArrays?
function blocktype(a::AbstractArray)

Check warning on line 48 in src/blocksparsearrayinterface/blocksparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksparsearrayinterface/blocksparsearrayinterface.jl#L48

Added line #L48 was not covered by tests
# TODO: Unfortunately, this doesn't always give
# a concrete type, even when it could be concrete, i.e.
#=
```julia
julia> eltype(blocks(BlockArray(randn(2, 2), [1, 1], [1, 1])))
Matrix{Float64} (alias for Array{Float64, 2})

julia> eltype(blocks(BlockedArray(randn(2, 2), [1, 1], [1, 1])))
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})

julia> eltype(blocks(randn(2, 2)))
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
```
=#
if isempty(blocks(a))
return eltype(blocks(a))

Check warning on line 64 in src/blocksparsearrayinterface/blocksparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksparsearrayinterface/blocksparsearrayinterface.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
end
return eltype(first(blocks(a)))

Check warning on line 66 in src/blocksparsearrayinterface/blocksparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/blocksparsearrayinterface/blocksparsearrayinterface.jl#L66

Added line #L66 was not covered by tests
end

abstract type AbstractBlockSparseArrayInterface <: AbstractSparseArrayInterface end

# TODO: Also support specifying the `blocktype` along with the `eltype`.
Expand Down
23 changes: 19 additions & 4 deletions src/blocksparsearrayinterface/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
using GPUArraysCore: @allowscalar
using MapBroadcast: Mapped
using DerivableInterfaces: DerivableInterfaces, @interface

Expand Down Expand Up @@ -36,14 +37,28 @@ function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type)
return similar(first(m.args), elt, combine_axes(axes.(m.args)...))
end

# Catches cases like `dest .= value` or `dest .= value1 .+ value2`.
# If the RHS is zero, this makes sure that the storage is emptied,
# which is logic that is handled by `fill!`.
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}})
# `[]` is used to unwrap zero-dimensional arrays.
value = @allowscalar bc.f(bc.args...)[]
return @interface BlockSparseArrayInterface() fill!(dest, value)
end

# Broadcasting implementation
# TODO: Delete this in favor of `DerivableInterfaces` version.
function Base.copyto!(
dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
) where {N}
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted)
# convert to map
# flatten and only keep the AbstractArray arguments
m = Mapped(bc)
@interface interface(bc) map!(m.f, dest, m.args...)
@interface interface(dest, bc) map!(m.f, dest, m.args...)
return dest
end

function Base.copyto!(
dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
) where {N}
copyto_blocksparse!(dest, bc)
return dest
end
Loading