Skip to content

Commit 10ea80c

Browse files
authored
Fix zero-dimensional conversion (#30)
1 parent 683b35d commit 10ea80c

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/map.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ArrayLayouts: LayoutArray
22
using BlockArrays: blockisequal
3-
using DerivableInterfaces: @interface, interface
3+
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
44
using LinearAlgebra: Adjoint, Transpose
55
using SparseArraysBase: SparseArraysBase, SparseArrayStyle
66

@@ -49,15 +49,30 @@ function reblock(
4949
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
5050
end
5151

52+
# `map!` specialized to zero-dimensional inputs.
53+
function map_zero_dim! end
54+
55+
@interface ::AbstractArrayInterface function map_zero_dim!(
56+
f, a_dest::AbstractArray, a_srcs::AbstractArray...
57+
)
58+
a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
59+
return a_dest
60+
end
61+
5262
# TODO: Move to `blocksparsearrayinterface/map.jl`.
5363
# TODO: Rewrite this so that it takes the blocking structure
5464
# made by combining the blocking of the axes (i.e. the blocking that
5565
# is used to determine `union_stored_blocked_cartesianindices(...)`).
5666
# `reblock` is a partial solution to that, but a bit ad-hoc.
5767
## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function.
58-
@interface ::AbstractBlockSparseArrayInterface function Base.map!(
68+
@interface interface::AbstractBlockSparseArrayInterface function Base.map!(
5969
f, a_dest::AbstractArray, a_srcs::AbstractArray...
6070
)
71+
if iszero(ndims(a_dest))
72+
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
73+
return a_dest
74+
end
75+
6176
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
6277
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
6378
BI_dest = blockindexrange(a_dest, I)

test/basics/test_basics.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,9 @@ arrayts = (Array, JLArray)
187187
@test iszero(@allowscalar(a[CartesianIndex()]))
188188
@test a[Block()] == dev(fill(0))
189189
@test iszero(@allowscalar(a[Block()][]))
190-
# Broken:
191-
## @test b[Block()[]] == 2
190+
@test @allowscalar(a[Block()[]]) == 0
191+
@test Array(a) isa Array{elt,0}
192+
@test Array(a) == fill(0)
192193
for b in (
193194
(b = copy(a); @allowscalar b[] = 2; b),
194195
(b = copy(a); @allowscalar b[CartesianIndex()] = 2; b),
@@ -206,6 +207,8 @@ arrayts = (Array, JLArray)
206207
@test b[Block()] == dev(fill(2))
207208
@test @allowscalar(b[Block()][]) == 2
208209
@test @allowscalar(b[Block()[]]) == 2
210+
@test Array(b) isa Array{elt,0}
211+
@test Array(b) == fill(2)
209212
end
210213
end
211214

0 commit comments

Comments
 (0)