Skip to content

Commit aea5f34

Browse files
authored
Improve type inference in mortar (#212)
* type inference in block sizes * version bump to v0.16.17 * test for different-sized axes
1 parent 1a101c7 commit aea5f34

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BlockArrays"
22
uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
3-
version = "0.16.16"
3+
version = "0.16.17"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/blockarray.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,10 @@ function sizes_from_blocks(blocks::AbstractArray{<:Any, N}, _) where N
266266
error("All blocks must have ndims consistent with ndims = $N of `blocks` array.")
267267
end
268268
fullsizes = map!(size, Array{NTuple{N,Int}, N}(undef, size(blocks)), blocks)
269-
block_sizes = ntuple(ndims(blocks)) do i
270-
[s[i] for s in view(fullsizes, ntuple(j -> j == i ? (:) : 1, ndims(blocks))...)]
269+
fR = reinterpret(reshape, Int, fullsizes)
270+
stfR = strides(fR)
271+
block_sizes = ntuple(N) do i
272+
fR[range(i, step = stfR[i+1], length=size(fullsizes, i))]
271273
end
272274
checksizes(fullsizes, block_sizes)
273275
return block_sizes

test/test_blockarrays.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,22 @@ end
125125
@test_throws DimensionMismatch BlockArray([1,2,3],[1,1])
126126

127127
@testset "mortar" begin
128-
@testset for sizes in [(1:3,), (1:3, 1:3), (1:3, 1:3, 1:3)]
128+
@testset for sizes in [(1:3,), (1:3, 1:4), (1:3, 1:4, 1:2)]
129129
dims = sum.(sizes)
130-
A = BlockArray(copy(reshape(1:prod(dims), dims)), sizes...)
131-
@test mortar(A.blocks) == A
130+
A = @inferred BlockArray(copy(reshape(1:prod(dims), dims)), sizes...)
131+
@test @inferred mortar(A.blocks) == A
132+
if length(dims) == 2
133+
# compare with hvcat
134+
rows = ntuple(_->length(sizes[2]), length(sizes[1]))
135+
@test mortar(A.blocks) == hvcat(rows, permutedims(A.blocks)...)
136+
end
132137
end
133138

134-
ret = mortar([spzeros(2), spzeros(3)])
139+
ret = @inferred mortar([spzeros(2), spzeros(3)])
135140
@test eltype(ret.blocks) <: SparseVector
136141
@test axes(ret) == (blockedrange([2, 3]),)
137142

138-
ret = mortar(
143+
ret = @inferred mortar(
139144
(spzeros(1, 3), spzeros(1, 4)),
140145
(spzeros(2, 3), spzeros(2, 4)),
141146
(spzeros(5, 3), spzeros(5, 4)),
@@ -157,6 +162,17 @@ end
157162
(zeros(2, 3), zeros(111, 222)),
158163
)
159164
end
165+
166+
@testset "sizes_from_blocks" begin
167+
blocks = reshape([rand(2,2), zeros(1,2),
168+
zeros(2,3), rand(1,3)], 2, 2);
169+
@test @inferred BlockArrays.sizes_from_blocks(blocks) == ([2,1], [2,3])
170+
blocks = reshape(
171+
[rand(2,2), zeros(1,2), zeros(4,2),
172+
zeros(2,3), rand(1,3), zeros(4,3),
173+
zeros(2,1), zeros(1,1), rand(4,1)], 3, 3);
174+
@test @inferred BlockArrays.sizes_from_blocks(blocks) == ([2, 1, 4], [2, 3, 1])
175+
end
160176
end
161177

162178
@testset "BlockVector" begin

0 commit comments

Comments
 (0)