Skip to content

Commit 4d1e02f

Browse files
authored
lower to standard broadcast for PseudoBlockArray (#193)
* lower to standard broadcast for PseudoBlockArray * Update blockbroadcast.jl * Update blockbroadcast.jl * Update test_blockbroadcast.jl * Update test_blockbroadcast.jl * v0.16.8
1 parent d4c8b6a commit 4d1e02f

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BlockArrays"
22
uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
3-
version = "0.16.7"
3+
version = "0.16.8"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/blockbroadcast.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,19 @@ function copyto!(dest::AbstractVector,
220220
end
221221
return _fast_blockbradcast_copyto!(dest, bc)
222222
end
223-
@inline function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <:AbstractBlockStyle}
223+
@inline function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <:BlockStyle}
224224
bcf = Broadcast.instantiate(Broadcast.flatten(Broadcasted{Nothing}(bc.f, bc.args, bc.axes)))
225225
return Broadcasted{Style}(bcf.f, bcf.args, bcf.axes)
226226
end
227227

228+
_removeblocks(a::Broadcasted) = broadcasted(a.f, map(_removeblocks,a.args)...)
229+
_removeblocks(a::PseudoBlockArray) = a.blocks
230+
_removeblocks(a::BlockSlice) = a.indices
231+
_removeblocks(a::Adjoint) = _removeblocks(parent(a))'
232+
_removeblocks(a::Transpose) = transpose(_removeblocks(parent(a)))
233+
_removeblocks(a::SubArray{<:Any,N,<:PseudoBlockArray}) where N = view(_removeblocks(parent(a)), map(_removeblocks, parentindices(a))...)
234+
_removeblocks(a) = a
235+
copy(bc::Broadcasted{PseudoBlockStyle{N}}) where N = PseudoBlockArray(materialize(_removeblocks(bc)), axes(bc))
228236

229237
for op in (:+, :-, :*)
230238
@eval function copy(bc::Broadcasted{BlockStyle{N},<:Any,typeof($op),<:Tuple{<:AbstractArray{<:Number,N}}}) where N

test/test_blockbroadcast.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
5050
@test axes(B .+ 1) == axes(B)
5151
@test axes(A .+ 1 .+ B) == axes(B)
5252
@test A .+ 1 .+ B == Vector(A) .+ 1 .+ B == Vector(A) .+ 1 .+ Matrix(B)
53+
54+
@testset "preserve structure" begin
55+
x = PseudoBlockArray(1:6, Fill(3,2))
56+
@test x + x isa PseudoBlockVector{Int,<:AbstractRange}
57+
@test 2x + x isa PseudoBlockVector{Int,<:AbstractRange}
58+
@test 2 .* (x .+ 1) isa PseudoBlockVector{Int,<:AbstractRange}
59+
end
5360
end
5461

5562
@testset "Mixed" begin
@@ -184,9 +191,11 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
184191
@test Base.BroadcastStyle(typeof(a')) isa BlockArrays.PseudoBlockStyle{2}
185192
@test Base.BroadcastStyle(typeof(b')) isa BlockArrays.BlockStyle{2}
186193

187-
@test exp.(a') == exp.(b') == exp.(Vector(a)')
194+
@test exp.(a') == exp.(b') == exp.(transpose(a)) == exp.(transpose(b)) == exp.(Vector(a)')
188195
@test exp.(a') isa PseudoBlockArray
196+
@test exp.(transpose(a)) isa PseudoBlockArray
189197
@test exp.(b') isa BlockArray
198+
@test exp.(transpose(b)) isa BlockArray
190199
end
191200

192201
@testset "subarray" begin

0 commit comments

Comments
 (0)