Skip to content

Commit 0bb7f04

Browse files
committed
Fix nested broadcasting of BlockedArray
1 parent 90ddfa9 commit 0bb7f04

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BlockArrays"
22
uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
3-
version = "1.7.0"
3+
version = "1.7.1"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/blockbroadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ end
131131
@inline function Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {NDims, BS<:AbstractBlockStyle{NDims}}
132132
dest_reshaped = ndims(dest) == NDims ? dest : reshape(dest, size(bc))
133133
bc2 = Broadcast.instantiate(
134-
Broadcast.Broadcasted{BS}(bc.f, bc.args,
135-
map(combine_blockaxes, axes(dest_reshaped), axes(bc))))
134+
Broadcast.flatten(Broadcast.Broadcasted{BS}(bc.f, bc.args,
135+
map(combine_blockaxes, axes(dest_reshaped), axes(bc)))))
136136
copyto!(dest_reshaped, bc2)
137137
return dest
138138
end

test/test_blockbroadcast.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,26 @@ using StaticArrays
6767
@test A .+ 1 .+ B == Vector(A) .+ 1 .+ B == Vector(A) .+ 1 .+ Matrix(B)
6868

6969
@testset "preserve structure" begin
70-
x = BlockedArray(1:6, Fill(3,2))
71-
@test x + x isa BlockedVector{Int,<:AbstractRange}
72-
@test 2x + x isa BlockedVector{Int,<:AbstractRange}
73-
@test 2 .* (x .+ 1) isa BlockedVector{Int,<:AbstractRange}
70+
x = BlockedArray(1:6, Fill(3,2))
71+
@test x + x isa BlockedVector{Int,<:AbstractRange}
72+
@test 2x + x isa BlockedVector{Int,<:AbstractRange}
73+
@test 2 .* (x .+ 1) isa BlockedVector{Int,<:AbstractRange}
74+
end
75+
76+
@testset "nested in-place broadcast" begin
77+
x = BlockedVector(randn(4), [2, 2])
78+
y = BlockedVector(randn(4), [2, 2])
79+
dest = copy(x)
80+
dest .+= 2 .* y
81+
@test dest x + 2y
82+
end
83+
84+
@testset "0-dim nested in-place broadcast" begin
85+
x = BlockedArray(randn(()))
86+
y = BlockedArray(randn(()))
87+
dest = copy(x)
88+
dest .+= 2 .* y
89+
@test dest x + 2y
7490
end
7591
end
7692

0 commit comments

Comments
 (0)