Skip to content

Commit 13a3198

Browse files
committed
Add cat implementation
1 parent c98436b commit 13a3198

File tree

7 files changed

+107
-0
lines changed

7 files changed

+107
-0
lines changed

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl")
77
include("blocksparsearrayinterface/map.jl")
88
include("blocksparsearrayinterface/arraylayouts.jl")
99
include("blocksparsearrayinterface/views.jl")
10+
include("blocksparsearrayinterface/cat.jl")
1011
include("abstractblocksparsearray/abstractblocksparsearray.jl")
1112
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
1213
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
@@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl")
1718
include("abstractblocksparsearray/broadcast.jl")
1819
include("abstractblocksparsearray/map.jl")
1920
include("abstractblocksparsearray/linearalgebra.jl")
21+
include("abstractblocksparsearray/cat.jl")
2022
include("blocksparsearray/defaults.jl")
2123
include("blocksparsearray/blocksparsearray.jl")
2224
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# TODO: Change to `AnyAbstractBlockSparseArray`.
2+
function Base.cat(as::BlockSparseArrayLike...; dims)
3+
# TODO: Use `sparse_cat` instead, currently
4+
# that erroneously allocates too many blocks that are
5+
# zero and shouldn't be stored.
6+
return blocksparse_cat(as...; dims)
7+
end
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
2+
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat!
3+
4+
# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`.
5+
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`.
6+
function SparseArrayInterface.axis_cat(
7+
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
8+
)
9+
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
10+
end
11+
12+
# that erroneously allocates too many blocks that are
13+
# zero and shouldn't be stored.
14+
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
15+
sparse_cat!(blocks(a_dest), blocks.(as)...; dims)
16+
return a_dest
17+
end
18+
19+
# TODO: Delete this in favor of `sparse_cat`, currently
20+
# that erroneously allocates too many blocks that are
21+
# zero and shouldn't be stored.
22+
function blocksparse_cat(as::AbstractArray...; dims)
23+
a_dest = allocate_cat_output(as...; dims)
24+
blocksparse_cat!(a_dest, as...; dims)
25+
return a_dest
26+
end

NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl")
1212
include("sparsearrayinterface/conversion.jl")
1313
include("sparsearrayinterface/wrappers.jl")
1414
include("sparsearrayinterface/zero.jl")
15+
include("sparsearrayinterface/cat.jl")
1516
include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl")
1617
include("abstractsparsearray/abstractsparsearray.jl")
1718
include("abstractsparsearray/abstractsparsematrix.jl")
@@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl")
2425
include("abstractsparsearray/map.jl")
2526
include("abstractsparsearray/baseinterface.jl")
2627
include("abstractsparsearray/convert.jl")
28+
include("abstractsparsearray/cat.jl")
2729
include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl")
2830
include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl")
2931
end
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# TODO: Change to `AnyAbstractSparseArray`.
2+
function Base.cat(as::SparseArrayLike...; dims)
3+
return sparse_cat(as...; dims)
4+
end
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
unval(x) = x
2+
unval(::Val{x}) where {x} = x
3+
4+
# TODO: Assert that `a1` and `a2` start at one.
5+
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
6+
function axis_cat(
7+
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
8+
)
9+
return axis_cat(axis_cat(a1, a2), a_rest...)
10+
end
11+
function cat_axes(as::AbstractArray...; dims)
12+
return ntuple(length(first(axes.(as)))) do dim
13+
return if dim in unval(dims)
14+
axis_cat(map(axes -> axes[dim], axes.(as))...)
15+
else
16+
axes(first(as))[dim]
17+
end
18+
end
19+
end
20+
21+
function allocate_cat_output(as::AbstractArray...; dims)
22+
eltype_dest = promote_type(eltype.(as)...)
23+
axes_dest = cat_axes(as...; dims)
24+
# TODO: Promote the block types of the inputs rather than using
25+
# just the first input.
26+
# TODO: Make this customizable with `cat_similar`.
27+
# TODO: Base the zero element constructor on those of the inputs,
28+
# for example block sparse arrays.
29+
return similar(first(as), eltype_dest, axes_dest...)
30+
end
31+
32+
# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
33+
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
34+
# This is very similar to the `Base.cat` implementation but handles zero values better.
35+
function cat_offset!(
36+
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
37+
)
38+
inds = ntuple(ndims(a_dest)) do dim
39+
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
40+
end
41+
a_dest[inds...] = a1
42+
new_offsets = ntuple(ndims(a_dest)) do dim
43+
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
44+
end
45+
cat_offset!(a_dest, new_offsets, a_rest...; dims)
46+
return a_dest
47+
end
48+
function cat_offset!(a_dest::AbstractArray, offsets; dims)
49+
return a_dest
50+
end
51+
52+
# TODO: Define a generic `cat!` function.
53+
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
54+
offsets = ntuple(zero, ndims(a_dest))
55+
# TODO: Fill `a_dest` with zeros if needed.
56+
cat_offset!(a_dest, offsets, as...; dims)
57+
return a_dest
58+
end
59+
60+
function sparse_cat(as::AbstractArray...; dims)
61+
a_dest = allocate_cat_output(as...; dims)
62+
sparse_cat!(a_dest, as...; dims)
63+
return a_dest
64+
end

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1})
144144
end
145145

146146
# Slicing
147+
# TODO: Make this handle more general slicing operations,
148+
# base it off of `ArrayLayouts.sub_materialize`.
147149
function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...)
148150
inds = CartesianIndices(I)
149151
for i in stored_indices(value)

0 commit comments

Comments
 (0)