Skip to content

Commit 6244905

Browse files
authored
[BlockSparseArrays] Update to BlockArrays v1.1, fix some issues with nested views (#1503)
* [BlockSparseArrays] Update to BlockArrays v1.1, fix some issues with nested views * [NDTensors] Bump to v0.3.31, BlockArrays v1.1
1 parent b7cdc8d commit 6244905

File tree

5 files changed

+216
-105
lines changed

5 files changed

+216
-105
lines changed

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ end
215215

216216
function blockrange(
217217
axis::AbstractUnitRange,
218-
r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
218+
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
219219
)
220220
return map(b -> Block(b), blocks(r))
221221
end
@@ -271,7 +271,7 @@ end
271271
function blockindices(
272272
a::AbstractUnitRange,
273273
b::Block,
274-
r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
274+
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
275275
)
276276
# TODO: Change to iterate over `BlockRange(r)`
277277
# once https://github.com/JuliaArrays/BlockArrays.jl/issues/404

src/abstractblocksparsearray/view.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,3 @@ end
2727
function Base.view(a::BlockSparseArrayLike{<:Any,1}, index::Block{1})
2828
return blocksparse_view(a, index)
2929
end
30-
31-
function Base.view(a::BlockSparseArrayLike, indices::BlockIndexRange)
32-
return view(view(a, block(indices)), indices.indices...)
33-
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,18 @@ blocktype(a::BlockSparseArrayLike) = eltype(blocks(a))
120120
blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype))
121121

122122
using ArrayLayouts: ArrayLayouts
123-
## function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::Vararg{Int,N}) where {N}
124-
## return ArrayLayouts.layout_getindex(a, I...)
125-
## end
126123
function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::CartesianIndices{N}) where {N}
127124
return ArrayLayouts.layout_getindex(a, I)
128125
end
129126
function Base.getindex(
130-
a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange,N}
127+
a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange{<:Integer},N}
131128
) where {N}
132129
return ArrayLayouts.layout_getindex(a, I...)
133130
end
134131
# TODO: Define `AnyBlockSparseMatrix`.
135-
function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange,2})
132+
function Base.getindex(
133+
a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange{<:Integer},2}
134+
)
136135
return ArrayLayouts.layout_getindex(a, I...)
137136
end
138137

@@ -199,7 +198,7 @@ end
199198

200199
# Needed by `BlockArrays` matrix multiplication interface
201200
function Base.similar(
202-
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractUnitRange}}
201+
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
203202
)
204203
return similar(arraytype, eltype(arraytype), axes)
205204
end
@@ -210,53 +209,45 @@ end
210209
# Delete once we drop support for older versions of Julia.
211210
function Base.similar(
212211
arraytype::Type{<:BlockSparseArrayLike},
213-
axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}},
214-
)
215-
return similar(arraytype, eltype(arraytype), axes)
216-
end
217-
218-
# Needed by `BlockArrays` matrix multiplication interface
219-
# Fixes ambiguity error with `BlockArrays.jl`.
220-
function Base.similar(
221-
arraytype::Type{<:BlockSparseArrayLike},
222-
axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
212+
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
223213
)
224214
return similar(arraytype, eltype(arraytype), axes)
225215
end
226216

227-
# Needed by `BlockArrays` matrix multiplication interface
228-
# Fixes ambiguity error with `BlockArrays.jl`.
217+
# Fixes ambiguity error with `BlockArrays`.
229218
function Base.similar(
230219
arraytype::Type{<:BlockSparseArrayLike},
231-
axes::Tuple{
232-
AbstractBlockedUnitRange,AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}
233-
},
220+
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
234221
)
235222
return similar(arraytype, eltype(arraytype), axes)
236223
end
237224

238-
# Needed by `BlockArrays` matrix multiplication interface
239-
# Fixes ambiguity error with `BlockArrays.jl`.
225+
# Fixes ambiguity error with `BlockArrays`.
240226
function Base.similar(
241227
arraytype::Type{<:BlockSparseArrayLike},
242228
axes::Tuple{
243-
AbstractUnitRange{Int},AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}
229+
AbstractUnitRange{<:Integer},
230+
AbstractBlockedUnitRange{<:Integer},
231+
Vararg{AbstractUnitRange{<:Integer}},
244232
},
245233
)
246234
return similar(arraytype, eltype(arraytype), axes)
247235
end
248236

249237
# Needed for disambiguation
250238
function Base.similar(
251-
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractBlockedUnitRange}}
239+
arraytype::Type{<:BlockSparseArrayLike},
240+
axes::Tuple{Vararg{AbstractBlockedUnitRange{<:Integer}}},
252241
)
253242
return similar(arraytype, eltype(arraytype), axes)
254243
end
255244

256245
# Needed by `BlockArrays` matrix multiplication interface
257246
# TODO: Define a `blocksparse_similar` function.
258247
function Base.similar(
259-
arraytype::Type{<:BlockSparseArrayLike}, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}}
248+
arraytype::Type{<:BlockSparseArrayLike},
249+
elt::Type,
250+
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
260251
)
261252
# TODO: Make generic for GPU, maybe using `blocktype`.
262253
# TODO: For non-block axes this should output `Array`.
@@ -265,7 +256,7 @@ end
265256

266257
# TODO: Define a `blocksparse_similar` function.
267258
function Base.similar(
268-
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}}
259+
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
269260
)
270261
# TODO: Make generic for GPU, maybe using `blocktype`.
271262
# TODO: For non-block axes this should output `Array`.
@@ -277,7 +268,9 @@ end
277268
function Base.similar(
278269
a::BlockSparseArrayLike,
279270
elt::Type,
280-
axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractBlockedUnitRange}},
271+
axes::Tuple{
272+
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
273+
},
281274
)
282275
# TODO: Make generic for GPU, maybe using `blocktype`.
283276
# TODO: For non-block axes this should output `Array`.
@@ -289,13 +282,37 @@ end
289282
function Base.similar(
290283
a::BlockSparseArrayLike,
291284
elt::Type,
292-
axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}},
285+
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
293286
)
294287
# TODO: Make generic for GPU, maybe using `blocktype`.
295288
# TODO: For non-block axes this should output `Array`.
296289
return BlockSparseArray{elt}(undef, axes)
297290
end
298291

292+
# Fixes ambiguity error with `BlockArrays`.
293+
function Base.similar(
294+
a::BlockSparseArrayLike,
295+
elt::Type,
296+
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
297+
)
298+
# TODO: Make generic for GPU, maybe using `blocktype`.
299+
# TODO: For non-block axes this should output `Array`.
300+
return BlockSparseArray{elt}(undef, axes)
301+
end
302+
303+
# Fixes ambiguity errors with BlockArrays.
304+
function Base.similar(
305+
a::BlockSparseArrayLike,
306+
elt::Type,
307+
axes::Tuple{
308+
AbstractUnitRange{<:Integer},
309+
AbstractBlockedUnitRange{<:Integer},
310+
Vararg{AbstractUnitRange{<:Integer}},
311+
},
312+
)
313+
return BlockSparseArray{elt}(undef, axes)
314+
end
315+
299316
# TODO: Define a `blocksparse_similar` function.
300317
# Fixes ambiguity error with `StaticArrays`.
301318
function Base.similar(

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,24 @@ function blocksparse_fill!(a::AbstractArray, value)
102102
return a
103103
end
104104
for b in BlockRange(a)
105-
a[b] .= value
105+
# We can't use:
106+
# ```julia
107+
# a[b] .= value
108+
# ```
109+
# since that would lead to a stack overflow,
110+
# because broadcasting calls `fill!`.
111+
112+
# TODO: Ideally we would use:
113+
# ```julia
114+
# @view!(a[b]) .= value
115+
# ```
116+
# but that doesn't work on `SubArray` right now.
117+
118+
# This line is needed to instantiate blocks
119+
# that aren't instantiated yet. Maybe
120+
# we can make this work without this line?
121+
blocks(a)[Int.(Tuple(b))...] = blocks(a)[Int.(Tuple(b))...]
122+
blocks(a)[Int.(Tuple(b))...] .= value
106123
end
107124
return a
108125
end
@@ -268,6 +285,10 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) w
268285
end
269286
function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N}
270287
parent_blocks = view(blocks(parent(a.array)), axes(a)...)
288+
# TODO: The following line is required to instantiate
289+
# uninstantiated blocks, maybe use `@view!` instead,
290+
# or some other code pattern.
291+
parent_blocks[I...] = parent_blocks[I...]
271292
return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] =
272293
value
273294
end

0 commit comments

Comments
 (0)