Skip to content

Commit 1701179

Browse files
committed
[BlockSparseArrys] Zero dimensional block sparse array
1 parent f9b6309 commit 1701179

File tree

5 files changed

+62
-1
lines changed

5 files changed

+62
-1
lines changed

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,16 @@ function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T,N}}) where {T,N
2222
return SparseArrayDOK{AbstractArray{T,N},N}
2323
end
2424

25-
## # Specialized in order to fix ambiguity error with `BlockArrays`.
25+
# Specialized in order to fix ambiguity error with `BlockArrays`.
2626
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
2727
return blocksparse_getindex(a, I...)
2828
end
2929

30+
# Specialized in order to fix ambiguity error with `BlockArrays`.
31+
function Base.getindex(a::AbstractBlockSparseArray{<:Any,0})
32+
return blocksparse_getindex(a)
33+
end
34+
3035
## # Fix ambiguity error with `BlockArrays`.
3136
## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N}
3237
## return ArrayLayouts.layout_getindex(a, I)
@@ -51,6 +56,12 @@ function Base.setindex!(
5156
return a
5257
end
5358

59+
# Fix ambiguity error.
60+
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
61+
blocksparse_setindex!(a, value)
62+
return a
63+
end
64+
5465
function Base.setindex!(
5566
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
5667
) where {N}

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ function Base.getindex(
9292
)
9393
return ArrayLayouts.layout_getindex(a, I...)
9494
end
95+
# Fixes ambiguity error.
96+
function Base.getindex(a::BlockSparseArrayLike{<:Any,0})
97+
return ArrayLayouts.layout_getindex(a)
98+
end
9599

96100
# TODO: Define `blocksparse_isassigned`.
97101
function Base.isassigned(
@@ -100,6 +104,11 @@ function Base.isassigned(
100104
return isassigned(blocks(a), Int.(index)...)
101105
end
102106

107+
# Fix ambiguity error.
108+
function Base.isassigned(a::BlockSparseArrayLike{<:Any,0})
109+
return isassigned(blocks(a))
110+
end
111+
103112
function Base.isassigned(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N}
104113
return isassigned(a, Tuple(index)...)
105114
end

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {
7272
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
7373
end
7474

75+
function BlockSparseArray{T,0}(axes::Tuple{}) where {T}
76+
return BlockSparseArray{T,0,default_arraytype(T, axes)}(axes)
77+
end
78+
7579
function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
7680
return BlockSparseArray{T,N}(blockedrange.(dims))
7781
end
@@ -84,6 +88,10 @@ function BlockSparseArray{T}(axes::Tuple{Vararg{AbstractUnitRange}}) where {T}
8488
return BlockSparseArray{T,length(axes)}(axes)
8589
end
8690

91+
function BlockSparseArray{T}(axes::Tuple{}) where {T}
92+
return BlockSparseArray{T,length(axes)}(axes)
93+
end
94+
8795
function BlockSparseArray{T}(dims::Vararg{Vector{Int}}) where {T}
8896
return BlockSparseArray{T}(dims)
8997
end
@@ -92,6 +100,10 @@ function BlockSparseArray{T}(axes::Vararg{AbstractUnitRange}) where {T}
92100
return BlockSparseArray{T}(axes)
93101
end
94102

103+
function BlockSparseArray{T}() where {T}
104+
return BlockSparseArray{T}(())
105+
end
106+
95107
function BlockSparseArray{T,N,A}(
96108
::UndefInitializer, dims::Tuple
97109
) where {T,N,A<:AbstractArray{T,N}}

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where
2222
return a[findblockindex.(axes(a), I)...]
2323
end
2424

25+
# Fix ambiguity error.
26+
function blocksparse_getindex(a::AbstractArray{<:Any,0})
27+
# TODO: Use `Block()[]` once https://github.com/JuliaArrays/BlockArrays.jl/issues/430
28+
# is fixed.
29+
return a[BlockIndex{0,Tuple{},Tuple{}}((), ())]
30+
end
31+
2532
# a[1:2, 1:2]
2633
# TODO: This definition means that the result of slicing a block sparse array
2734
# with a non-blocked unit range is blocked. We may want to change that behavior,
@@ -77,6 +84,14 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N
7784
return a
7885
end
7986

87+
# Fix ambiguity error.
88+
function blocksparse_setindex!(a::AbstractArray{<:Any,0}, value)
89+
# TODO: Use `Block()[]` once https://github.com/JuliaArrays/BlockArrays.jl/issues/430
90+
# is fixed.
91+
a[BlockIndex{0,Tuple{},Tuple{}}((), ())] = value
92+
return a
93+
end
94+
8095
function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N}) where {N}
8196
i = Int.(Tuple(block(I)))
8297
a_b = blocks(a)[i...]
@@ -86,6 +101,15 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N
86101
return a
87102
end
88103

104+
# Fix ambiguity error.
105+
function blocksparse_setindex!(a::AbstractArray{<:Any,0}, value, I::BlockIndex{0})
106+
a_b = blocks(a)[]
107+
a_b[] = value
108+
# Set the block, required if it is structurally zero.
109+
blocks(a)[] = a_b
110+
return a
111+
end
112+
89113
function blocksparse_fill!(a::AbstractArray, value)
90114
for b in BlockRange(a)
91115
# We can't use:

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ function sparse_getindex(a::AbstractArray, I::Vararg{Int})
8282
return sparse_getindex(a, CartesianIndex(I))
8383
end
8484

85+
# Fix ambiguity error.
86+
function sparse_getindex(a::AbstractArray{<:Any,0})
87+
return sparse_getindex(a, CartesianIndex())
88+
end
89+
8590
# Linear indexing
8691
function sparse_getindex(a::AbstractArray, I::CartesianIndex{1})
8792
return sparse_getindex(a, CartesianIndices(a)[I])

0 commit comments

Comments
 (0)