Skip to content

Commit 836da5c

Browse files
authored
Fix zero dimensional BlockArray and BlockedArray (#410)
* add broken tests * fix zerodim constructor and getindex * fix blockcheckbounds * fix tests * fix views * fix setindex * fix view * fix Array * fix BlockedArray view * test linalg * fix linear algebra * test for product * test ^ * same behavior as Array for .* and .^ * fix reshape * test view(::ReshapedArray) * fix getindex(::Block{0}) * revert unneeded changes
1 parent 4d283f6 commit 836da5c

File tree

8 files changed

+122
-3
lines changed

8 files changed

+122
-3
lines changed

src/abstractblockarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ end
101101

102102
blockcheckbounds(A::AbstractArray{T, N}, i::Block{N}) where {T,N} = blockcheckbounds(A, i.n...)
103103
blockcheckbounds(A::AbstractArray{T, N}, i::Vararg{Block{1},N}) where {T,N} = blockcheckbounds(A, Int.(i)...)
104+
blockcheckbounds(::AbstractArray{T, 0}) where {T} = true
104105
blockcheckbounds(A::AbstractVector{T}, i::Block{1}) where {T} = blockcheckbounds(A, Int(i))
105106

106107
"""
@@ -186,6 +187,7 @@ viewblock(block_arr, block) = Base.invoke(view, Tuple{AbstractArray, Any}, block
186187
blkind = BlockRange(blocksize(block_arr))[Int(block)]
187188
view(block_arr, blkind)
188189
end
190+
@inline view(zerodim::AbstractBlockArray{<:Any,0}) = view(zerodim.blocks[])
189191
@inline view(block_arr::AbstractBlockVector, block::Block{1}) = viewblock(block_arr, block)
190192
@propagate_inbounds view(block_arr::AbstractBlockArray, block::Block{1}...) = view(block_arr, Block(block))
191193

src/blockarray.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ function BlockArray{T}(arr::AbstractArray{T, N}, baxes::Tuple{Vararg{AbstractUni
265265
return _BlockArray(blocks, baxes)
266266
end
267267

268+
function BlockArray{T}(arr::AbstractArray{T, 0}, ::Tuple{}) where T
269+
blocks = Array{Array{T, 0},0}(undef)
270+
fill!(blocks, arr)
271+
return _BlockArray(blocks)
272+
end
273+
268274
BlockArray{T}(arr::AbstractArray{<:Any, N}, baxes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}}) where {T,N} =
269275
BlockArray{T}(convert(AbstractArray{T, N}, arr), baxes)
270276

@@ -463,12 +469,21 @@ const OffsetAxis = Union{Integer, UnitRange, Base.OneTo, Base.IdentityUnitRange}
463469
return v
464470
end
465471

472+
@inline function getindex(block_arr::BlockArray{T, 0}) where T
473+
return blocks(block_arr)[][]
474+
end
475+
466476
@inline function setindex!(block_arr::BlockArray{T, N}, v, i::Vararg{Integer, N}) where {T,N}
467477
@boundscheck checkbounds(block_arr, i...)
468478
@inbounds block_arr[findblockindex.(axes(block_arr), i)...] = v
469479
return block_arr
470480
end
471481

482+
@inline function setindex!(block_arr::BlockArray{<:Any, 0}, v)
483+
blocks(block_arr)[][] = v
484+
end
485+
486+
472487
############
473488
# Indexing #
474489
############
@@ -527,6 +542,11 @@ end
527542
########
528543
# Misc #
529544
########
545+
function Base.Array(zerodim::BlockArray{T, 0}) where {T}
546+
arr = Array{T}(undef)
547+
arr[] = zerodim[]
548+
return arr
549+
end
530550

531551
function Base.Array(block_array::BlockArray{T, N, R}) where {T,N,R}
532552
arr = Array{eltype(T)}(undef, size(block_array))
@@ -547,6 +567,8 @@ end
547567
# Temporary work around
548568
Base.reshape(block_array::BlockArray, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}}) where N =
549569
reshape(BlockedArray(block_array), axes)
570+
Base.reshape(block_array::BlockArray, ::Tuple{}) =
571+
reshape(BlockedArray(block_array), ()) # zerodim
550572
Base.reshape(block_array::BlockArray, dims::Tuple{Int,Vararg{Int}}) =
551573
reshape(BlockedArray(block_array), dims)
552574
Base.reshape(block_array::BlockArray, axes::Tuple{Union{Integer,Base.OneTo}, Vararg{Union{Integer,Base.OneTo}}}) =

src/blockedarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ const BlockedVector{T} = BlockedArray{T, 1}
9999
const BlockedVecOrMat{T} = Union{BlockedMatrix{T}, BlockedVector{T}}
100100

101101
# Auxiliary outer constructors
102+
BlockedArray(x::Number, ::Tuple{}) = x # zero dimensional
102103
@inline BlockedArray(blocks::R, baxes::BS) where {T,N,R<:AbstractArray{T,N},BS<:Tuple{Vararg{AbstractUnitRange{<:Integer},N}}} =
103104
BlockedArray{T, N, R,BS}(blocks, baxes)
104105

@@ -240,6 +241,7 @@ end
240241
############
241242
# Indexing #
242243
############
244+
@inline view(block_arr::BlockedArray{<:Any, 0}) = view(block_arr.blocks)
243245

244246
@inline function viewblock(block_arr::BlockedArray, block)
245247
range = getindex.(axes(block_arr), Block.(block.n))
@@ -300,6 +302,8 @@ Base.reshape(parent::BlockedArray, shp::Tuple{Union{Int,Base.OneTo}, Vararg{Unio
300302
reshape(parent, Base.to_shape(shp))
301303
Base.reshape(parent::BlockedArray, dims::Tuple{Int,Vararg{Int}}) =
302304
Base._reshape(parent, dims)
305+
Base.reshape(block_array::BlockedArray, ::Tuple{}) =
306+
_blocked_reshape(block_array, ()) # zero dim
303307

304308
"""
305309
resize!(a::BlockedVector, N::Block) -> BlockedVector

src/blockindices.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ BlockIndexRange(block::Block{N}, inds::Vararg{AbstractUnitRange{<:Integer},N}) w
202202

203203
block(R::BlockIndexRange) = R.block
204204

205+
getindex(::Block{0}) = Block()
205206
getindex(B::Block{N}, inds::Vararg{Integer,N}) where N = BlockIndex(B,inds)
206207
getindex(B::Block{N}, inds::Vararg{AbstractUnitRange{<:Integer},N}) where N = BlockIndexRange(B,inds)
207208
getindex(B::Block{1}, inds::Colon) = B

src/views.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,13 @@ end
8888
return reshape(view(A.parent, I[1:M]...), Val(N))
8989
end
9090

91-
@propagate_inbounds function Base.unsafe_view(
92-
A::Array{<:Any, N},
93-
I::Vararg{BlockSlice{<:BlockIndexRange{1}}, N}) where {N}
91+
@propagate_inbounds function Base.unsafe_view(
92+
A::Array,
93+
I1::BlockSlice{<:BlockIndexRange{1}},
94+
Is::Vararg{BlockSlice{<:BlockIndexRange{1}}},
95+
)
96+
I = (I1, Is...)
97+
@assert ndims(A) == length(I)
9498
return view(A, map(x -> x.indices, I)...)
9599
end
96100

test/test_blockarrays.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,61 @@ end
318318
@test all(isone, M)
319319
end
320320

321+
@testset "zero dim" begin
322+
zerodim = ones()
323+
@test view(zerodim) isa AbstractArray{Float64, 0} # check no type-piracy
324+
325+
ret = BlockArray{Float64}(undef)
326+
@test ret isa BlockArray{Float64, 0}
327+
fill!(ret, 0)
328+
@test size(ret) == ()
329+
@test all(iszero, ret)
330+
@test ret[Block()] == zeros()
331+
@test ret[Block()[]] == zeros()
332+
@test ret[] == 0
333+
@test view(ret, Block()) == zeros()
334+
@test Array(ret) == zeros()
335+
ret[] = 1
336+
@test ret[] == 1
337+
@test view(ret) == ones()
338+
view(ret)[] = 0
339+
@test ret[] == 0
340+
341+
ret = BlockArrays.BlockArray(zeros())
342+
@test ret isa BlockArray{Float64, 0}
343+
@test size(ret) == ()
344+
@test all(iszero, ret)
345+
@test ret[Block()] == zeros()
346+
347+
ret = BlockArrays.BlockArray(zeros(1,1))
348+
@test reshape(ret, ()) isa AbstractBlockArray{Float64, 0} # may be BlockedArray
349+
@test size(reshape(ret, ())) == ()
350+
351+
ret = BlockedArray{Float64}(undef)
352+
@test ret isa BlockedArray{Float64, 0}
353+
fill!(ret, 0)
354+
@test size(ret) == ()
355+
@test all(iszero, ret)
356+
@test ret[] == 0
357+
@test ret[Block()] == zeros()
358+
@test ret[Block()[]] == zeros()
359+
@test Array(ret) == zeros()
360+
ret[] = 1
361+
@test ret[] == 1
362+
@test view(ret) == ones()
363+
view(ret)[] = 0
364+
@test ret[] == 0
365+
366+
ret = BlockedArray(zeros())
367+
@test size(ret) == ()
368+
@test all(iszero, ret)
369+
@test ret[Block()] == zeros()
370+
371+
ret = BlockArrays.BlockedArray(zeros(1,1))
372+
@test reshape(ret, ()) isa BlockedArray{Float64, 0}
373+
@test size(reshape(ret, ())) == ()
374+
end
375+
321376
@testset "BlockVector" begin
322377
a_data = [1,2,3]
323378
a = BlockVector(a_data,[1,2])

test/test_blocklinalg.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,34 @@ import ArrayLayouts: DenseRowMajor, ColumnMajor, StridedLayout
77
bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b)
88

99
@testset "Linear Algebra" begin
10+
@testset "zerodim" begin
11+
a = BlockArray{Float64}(2*ones())
12+
@test 2a isa BlockArray{Float64,0}
13+
@test (2a)[] == 4
14+
@test a + a isa BlockArray{Float64,0}
15+
@test a + a == 2a
16+
@test norm(a) == 2
17+
18+
# same behavior as Array
19+
@test a .* a isa Float64
20+
@test a .* a == 4
21+
@test a .^ 2 isa Float64
22+
@test a .^ 2 == 4
23+
24+
a = BlockedArray{Float64}(2*ones())
25+
@test 2a isa BlockedArray{Float64,0}
26+
@test (2a)[] == 4
27+
@test a + a isa BlockedArray{Float64,0}
28+
@test a + a == 2a
29+
@test norm(a) == 2
30+
31+
# same behavior as Array
32+
@test a .* a isa Float64
33+
@test a .* a == 4
34+
@test a .^ 2 isa Float64
35+
@test a .^ 2 == 4
36+
end
37+
1038
@testset "BlockArray scalar * matrix" begin
1139
A = BlockArray{Float64}(randn(6,6), fill(2,3), 1:3)
1240
@test 2A == A*2 == 2Matrix(A)

test/test_blockviews.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b)
2121
@test collect(b) == [2,3]
2222
@test b[1] == 2
2323
@test b[1:2] == 2:3
24+
25+
rba = reshape(BlockedArray(collect(1:4),[2,2]), (2,2))
26+
@test view(rba, Block(1,1)[1:1,1:1]) == ones(1,1)
2427
end
2528

2629
@testset "block view" begin

0 commit comments

Comments
 (0)