Skip to content

Commit ee31d8b

Browse files
committed
Update for latest Derive, SparseArraysBase
1 parent 63bb180 commit ee31d8b

File tree

9 files changed

+75
-59
lines changed

9 files changed

+75
-59
lines changed

examples/README.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ julia> Pkg.add(url="https://github.com/ITensor/BlockSparseArrays.jl")
3636
# ## Examples
3737

3838
using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
39-
using BlockSparseArrays: BlockSparseArray, block_stored_length
39+
using BlockSparseArrays: BlockSparseArray, block_storedlength
4040
using Test: @test, @test_broken
4141

4242
function main()
@@ -63,13 +63,13 @@ function main()
6363
]
6464
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)
6565

66-
@test block_stored_length(b) == 2
66+
@test block_storedlength(b) == 2
6767

6868
## Blocks with discontiguous underlying data
6969
d_blocks = randn.(nz_block_sizes)
7070
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)
7171

72-
@test block_stored_length(b) == 2
72+
@test block_storedlength(b) == 2
7373

7474
## Access a block
7575
@test b[Block(1, 1)] == d_blocks[1]
@@ -93,7 +93,7 @@ function main()
9393
@test b + b Array(b) + Array(b)
9494
@test b + b isa BlockSparseArray
9595
## TODO: Fix this, broken.
96-
@test_broken block_stored_length(b + b) == 2
96+
@test_broken block_storedlength(b + b) == 2
9797

9898
scaled_b = 2b
9999
@test scaled_b 2Array(b)

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using BlockArrays:
2121
findblockindex
2222
using Dictionaries: Dictionary, Indices
2323
using GradedUnitRanges: blockedunitrange_getindices, to_blockindices
24-
using SparseArraysBase: SparseArraysBase, stored_length, stored_indices
24+
using SparseArraysBase: SparseArraysBase, storedlength, eachstoredindex
2525

2626
# A return type for `blocks(array)` when `array` isn't blocked.
2727
# Represents a vector with just that single block.
@@ -269,7 +269,7 @@ tuple_oneto(n) = ntuple(identity, n)
269269
function block_reshape(a::AbstractArray, axes::Tuple{Vararg{AbstractUnitRange}})
270270
reshaped_blocks_a = reshape(blocks(a), blocklength.(axes))
271271
reshaped_a = similar(a, axes)
272-
for I in stored_indices(reshaped_blocks_a)
272+
for I in eachstoredindex(reshaped_blocks_a)
273273
block_size_I = map(i -> length(axes[i][Block(I[i])]), tuple_oneto(length(axes)))
274274
# TODO: Better converter here.
275275
reshaped_a[Block(Tuple(I))] = reshape(reshaped_blocks_a[I], block_size_I)
@@ -465,8 +465,8 @@ function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange)
465465
return findblock(axis, first(range)):findblock(axis, last(range))
466466
end
467467

468-
function block_stored_indices(a::AbstractArray)
469-
return Block.(Tuple.(stored_indices(blocks(a))))
468+
function block_eachstoredindex(a::AbstractArray)
469+
return Block.(Tuple.(eachstoredindex(blocks(a))))
470470
end
471471

472472
_block(indices) = block(indices)
@@ -533,13 +533,13 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher
533533
return a
534534
end
535535

536-
function SparseArraysBase.stored_length(a::BlockView)
536+
function SparseArraysBase.storedlength(a::BlockView)
537537
# TODO: Store whether or not the block is stored already as
538538
# a Bool in `BlockView`.
539539
I = CartesianIndex(Int.(a.block))
540-
# TODO: Use `block_stored_indices`.
541-
if I stored_indices(blocks(a.array))
542-
return stored_length(blocks(a.array)[I])
540+
# TODO: Use `block_eachstoredindex`.
541+
if I eachstoredindex(blocks(a.array))
542+
return storedlength(blocks(a.array)[I])
543543
end
544544
return 0
545545
end
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
using BlockArrays: AbstractBlockArray, BlocksView
2-
using SparseArraysBase: SparseArraysBase, stored_length
2+
using SparseArraysBase: SparseArraysBase, storedlength
33

4-
function SparseArraysBase.stored_length(a::AbstractBlockArray)
5-
return sum(b -> stored_length(b), blocks(a); init=zero(Int))
4+
function SparseArraysBase.storedlength(a::AbstractBlockArray)
5+
return sum(b -> storedlength(b), blocks(a); init=zero(Int))
66
end
77

88
# TODO: Handle `BlocksView` wrapping a sparse array?
9-
function SparseArraysBase.storage_indices(a::BlocksView)
9+
function SparseArraysBase.eachstoredindex(a::BlocksView)
1010
return CartesianIndices(a)
1111
end

src/abstractblocksparsearray/map.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ end
6262
# is used to determine `union_stored_blocked_cartesianindices(...)`).
6363
# `reblock` is a partial solution to that, but a bit ad-hoc.
6464
# TODO: Move to `blocksparsearrayinterface/map.jl`.
65-
function SparseArraysBase.sparse_map!(
65+
## TODO: Make this an `@interface AbstractBlockSparseArray` function.
66+
function sparse_map!(
6667
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
6768
)
6869
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)

src/abstractblocksparsearray/sparsearrayinterface.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using BlockArrays: Block
2-
using SparseArraysBase: SparseArraysBase, sparse_storage, stored_indices
2+
using SparseArraysBase: SparseArraysBase, sparse_storage, eachstoredindex, storedlength
33

44
# Structure storing the block sparse storage
55
struct BlockSparseStorage{Arr<:AbstractBlockSparseArray}
@@ -11,7 +11,7 @@ function blockindex_to_cartesianindex(a::AbstractArray, blockindex)
1111
end
1212

1313
function Base.keys(s::BlockSparseStorage)
14-
stored_blockindices = Iterators.map(stored_indices(blocks(s.array))) do I
14+
stored_blockindices = Iterators.map(eachstoredindex(blocks(s.array))) do I
1515
block_axes = axes(blocks(s.array)[I])
1616
blockindices = Block(Tuple(I))[block_axes...]
1717
return Iterators.map(
@@ -29,10 +29,11 @@ function Base.iterate(s::BlockSparseStorage, args...)
2929
return iterate(values(s), args...)
3030
end
3131

32-
function SparseArraysBase.sparse_storage(a::AbstractBlockSparseArray)
33-
return BlockSparseStorage(a)
34-
end
32+
## TODO: Delete this, define `getstoredindex`, etc.
33+
## function SparseArraysBase.sparse_storage(a::AbstractBlockSparseArray)
34+
## return BlockSparseStorage(a)
35+
## end
3536

36-
function SparseArraysBase.stored_length(a::AnyAbstractBlockSparseArray)
37-
return sum(stored_length, sparse_storage(blocks(a)); init=zero(Int))
37+
function SparseArraysBase.storedlength(a::AnyAbstractBlockSparseArray)
38+
return sum(storedlength, sparse_storage(blocks(a)); init=zero(Int))
3839
end

src/abstractblocksparsearray/views.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ function BlockArrays.viewblock(
6868
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
6969
) where {N}
7070
I = CartesianIndex(Int.(block))
71-
# TODO: Use `block_stored_indices`.
72-
if I stored_indices(blocks(a))
71+
# TODO: Use `block_eachstoredindex`.
72+
if I eachstoredindex(blocks(a))
7373
return blocks(a)[I]
7474
end
7575
return BlockView(a, block)
@@ -185,8 +185,8 @@ function BlockArrays.viewblock(
185185
block::Vararg{Block{1},N},
186186
) where {T,N}
187187
I = CartesianIndex(Int.(block))
188-
# TODO: Use `block_stored_indices`.
189-
if I stored_indices(blocks(a))
188+
# TODO: Use `block_eachstoredindex`.
189+
if I eachstoredindex(blocks(a))
190190
return blocks(a)[I]
191191
end
192192
return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I))))

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using BlockArrays:
1313
blocks,
1414
findblockindex
1515
using LinearAlgebra: Adjoint, Transpose
16-
using SparseArraysBase: perm, iperm, stored_length, sparse_zero!
16+
using SparseArraysBase: perm, iperm, storedlength, sparse_zero!
1717

1818
blocksparse_blocks(a::AbstractArray) = error("Not implemented")
1919

@@ -136,8 +136,8 @@ function blocksparse_fill!(a::AbstractArray, value)
136136
return a
137137
end
138138

139-
function block_stored_length(a::AbstractArray)
140-
return stored_length(blocks(a))
139+
function block_storedlength(a::AbstractArray)
140+
return storedlength(blocks(a))
141141
end
142142

143143
# BlockArrays
@@ -169,18 +169,19 @@ function Base.getindex(
169169
blocks(parent(a.array))[_getindices(index, _invperm(a.array))...], _perm(a.array)
170170
)
171171
end
172-
function SparseArraysBase.stored_indices(a::SparsePermutedDimsArrayBlocks)
173-
return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array))))
172+
function SparseArraysBase.eachstoredindex(a::SparsePermutedDimsArrayBlocks)
173+
return map(I -> _getindices(I, _perm(a.array)), eachstoredindex(blocks(parent(a.array))))
174174
end
175175
# TODO: Either make this the generic interface or define
176176
# `SparseArraysBase.sparse_storage`, which is used
177177
# to defined this.
178-
function SparseArraysBase.stored_length(a::SparsePermutedDimsArrayBlocks)
179-
return length(stored_indices(a))
180-
end
181-
function SparseArraysBase.sparse_storage(a::SparsePermutedDimsArrayBlocks)
182-
return error("Not implemented")
178+
function SparseArraysBase.storedlength(a::SparsePermutedDimsArrayBlocks)
179+
return length(eachstoredindex(a))
183180
end
181+
## TODO: Delete.
182+
## function SparseArraysBase.sparse_storage(a::SparsePermutedDimsArrayBlocks)
183+
## return error("Not implemented")
184+
## end
184185

185186
reverse_index(index) = reverse(index)
186187
reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))
@@ -240,25 +241,32 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
240241
# TODO: Implement this properly.
241242
return true
242243
end
243-
function SparseArraysBase.stored_indices(a::SparseSubArrayBlocks)
244-
return stored_indices(view(blocks(parent(a.array)), blockrange(a)...))
244+
function SparseArraysBase.eachstoredindex(a::SparseSubArrayBlocks)
245+
return eachstoredindex(view(blocks(parent(a.array)), blockrange(a)...))
245246
end
246247
# TODO: Either make this the generic interface or define
247248
# `SparseArraysBase.sparse_storage`, which is used
248249
# to defined this.
249-
SparseArraysBase.stored_length(a::SparseSubArrayBlocks) = length(stored_indices(a))
250+
SparseArraysBase.storedlength(a::SparseSubArrayBlocks) = length(eachstoredindex(a))
250251

251252
## struct SparseSubArrayBlocksStorage{Array<:SparseSubArrayBlocks}
252253
## array::Array
253254
## end
254-
function SparseArraysBase.sparse_storage(a::SparseSubArrayBlocks)
255-
return map(I -> a[I], stored_indices(a))
256-
end
257255

258-
function SparseArraysBase.getindex_zero_function(a::SparseSubArrayBlocks)
259-
# TODO: Base it off of `getindex_zero_function(blocks(parent(a.array))`, but replace the
260-
# axes with `axes(a.array)`.
261-
return BlockZero(axes(a.array))
256+
## TODO: Delete.
257+
## function SparseArraysBase.sparse_storage(a::SparseSubArrayBlocks)
258+
## return map(I -> a[I], eachstoredindex(a))
259+
## end
260+
261+
## TODO: Delete.
262+
## function SparseArraysBase.getindex_zero_function(a::SparseSubArrayBlocks)
263+
## # TODO: Base it off of `getindex_zero_function(blocks(parent(a.array))`, but replace the
264+
## # axes with `axes(a.array)`.
265+
## return BlockZero(axes(a.array))
266+
## end
267+
268+
function SparseArraysBase.getunstoredindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
269+
error("Not implemented.")
262270
end
263271

264272
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
@@ -271,4 +279,4 @@ function blocksparse_blocks(
271279
end
272280

273281
using BlockArrays: BlocksView
274-
SparseArraysBase.stored_length(a::BlocksView) = length(a)
282+
SparseArraysBase.storedlength(a::BlocksView) = length(a)

src/blocksparsearrayinterface/blockzero.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,32 @@ struct BlockZero{Axes}
1818
axes::Axes
1919
end
2020

21-
function (f::BlockZero)(a::AbstractArray, I)
22-
return f(eltype(a), I)
21+
function (f::BlockZero)(a::AbstractArray, I...)
22+
return f(eltype(a), I...)
2323
end
2424

25-
function (f::BlockZero)(arraytype::Type{<:SubArray{<:Any,<:Any,P}}, I) where {P}
26-
return f(P, I)
25+
function (f::BlockZero)(arraytype::Type{<:SubArray{<:Any,<:Any,P}}, I...) where {P}
26+
return f(P, I...)
2727
end
2828

29-
function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I)
29+
function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I::CartesianIndex)
30+
return f(arraytype, Tuple(I)...)
31+
end
32+
33+
function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I::Int...)
3034
# TODO: Make sure this works for sparse or block sparse blocks, immutable
3135
# blocks, diagonal blocks, etc.!
32-
blck_size = block_size(f.axes, Block(Tuple(I)))
36+
blck_size = block_size(f.axes, Block(I))
3337
blck_type = similartype(arraytype, blck_size)
3438
return fill!(blck_type(undef, blck_size), false)
3539
end
3640

3741
# Fallback so that `SparseArray` with scalar elements works.
38-
function (f::BlockZero)(blocktype::Type{<:Number}, I)
42+
function (f::BlockZero)(blocktype::Type{<:Number}, I...)
3943
return zero(blocktype)
4044
end
4145

4246
# Fallback to Array if it is abstract
43-
function (f::BlockZero)(arraytype::Type{AbstractArray{T,N}}, I) where {T,N}
44-
return f(Array{T,N}, I)
47+
function (f::BlockZero)(arraytype::Type{AbstractArray{T,N}}, I...) where {T,N}
48+
return f(Array{T,N}, I...)
4549
end

src/blocksparsearrayinterface/cat.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ using SparseArraysBase: SparseArraysBase, allocate_cat_output, sparse_cat!
33

44
# TODO: Maybe move to `SparseArraysBaseBlockArraysExt`.
55
# TODO: Handle dual graded unit ranges, for example in a new `SparseArraysBaseGradedUnitRangesExt`.
6-
function SparseArraysBase.axis_cat(
6+
## TODO: Add this back.
7+
## function SparseArraysBase.axis_cat(
8+
function axis_cat(
79
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
810
)
911
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))

0 commit comments

Comments
 (0)