Skip to content

Commit b4e9812

Browse files
authored
Exploit special case broadcasting for *, - , +, /, \ (#94)
* Exploit special case broadcasting for *, - , +, /, \ * Use broadcast instead of map * Don't override broadcasted
1 parent aeb0664 commit b4e9812

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/blockbroadcast.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,24 @@ end
195195
bcf = Broadcast.flatten(Broadcasted{Nothing}(bc.f, bc.args, bc.axes))
196196
return Broadcasted{Style}(bcf.f, bcf.args, bcf.axes)
197197
end
198+
199+
200+
for op in (:+, :-, :*)
201+
@eval function copy(bc::Broadcasted{BlockStyle{N},<:Any,typeof($op),<:Tuple{<:BlockArray{<:Number,N}}}) where N
202+
(A,) = bc.args
203+
_BlockArray(broadcast(a -> broadcast($op, a), A.blocks), blocksizes(A))
204+
end
205+
end
206+
207+
for op in (:+, :-, :*, :/, :\)
208+
@eval begin
209+
function copy(bc::Broadcasted{BlockStyle{N},<:Any,typeof($op),<:Tuple{<:Number,<:BlockArray{<:Number,N}}}) where N
210+
x,A = bc.args
211+
_BlockArray(broadcast((x,a) -> broadcast($op, x, a), x, A.blocks), blocksizes(A))
212+
end
213+
function copy(bc::Broadcasted{BlockStyle{N},<:Any,typeof($op),<:Tuple{<:BlockArray{<:Number,N},<:Number}}) where N
214+
A,x = bc.args
215+
_BlockArray(broadcast((a,x) -> broadcast($op, a, x), A.blocks,x), blocksizes(A))
216+
end
217+
end
218+
end

test/test_blockbroadcast.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,16 @@ using BlockArrays, Test
9898
z2 = copy(z)
9999
@test (@. z = x + y + z; z) == (@. z2 = x2 + y2 + z2; z2)
100100
end
101+
102+
@testset "Special broadcast" begin
103+
v = mortar([1:3,4:7])
104+
@test broadcast(+, v) isa BlockVector{Int,Vector{UnitRange{Int}}}
105+
@test broadcast(+, v) == v
106+
@test broadcast(-, v) isa BlockVector{Int,Vector{StepRange{Int,Int}}}
107+
@test broadcast(-, v) == -v == -Vector(v)
108+
@test broadcast(+, v, 1) isa BlockVector{Int,Vector{UnitRange{Int}}}
109+
@test broadcast(+, v, 1) == Vector(v).+1
110+
@test broadcast(*, 2, v) isa BlockVector{Int,Vector{StepRange{Int,Int}}}
111+
@test broadcast(*, 2, v) == 2Vector(v)
112+
end
101113
end

0 commit comments

Comments
 (0)