Skip to content

Commit cb41530

Browse files
willtebbuttdlfivefifty
authored andcommitted
Binary broadcasting fix + tidy up (#66)
* Update binary broadcasting + gitignore * Reorganise a bit
1 parent 1a9c871 commit cb41530

File tree

3 files changed

+65
-38
lines changed

3 files changed

+65
-38
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
deps/deps.jl
55
.DS_Store
6+
Manifest.toml

src/fillbroadcast.jl

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,54 @@
1-
for op in (:+, :-)
2-
@eval broadcasted(::DefaultArrayStyle{N}, ::typeof($op), r1::AbstractFill{T,N}, r2::AbstractFill{V,N}) where {T,V,N} =
3-
$op(r1, r2)
4-
end
51

6-
function broadcasted(::DefaultArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::Zeros{V,N}) where {T,V,N}
7-
axes(a) axes(b) && throw(DimensionMismatch("dimensions must match."))
8-
Zeros{promote_type(T,V)}(axes(a))
9-
end
2+
### Unary broadcasting
103

11-
function _broadcasted_mul(a::AbstractArray{T}, b::Zeros{V}) where {T,V}
12-
axes(a) axes(b) && throw(DimensionMismatch("dimensions must match."))
13-
Zeros{promote_type(T,V)}(axes(a))
14-
end
15-
function broadcasted(::DefaultArrayStyle{N}, ::typeof(*), a::AbstractArray{T,N}, b::Zeros{V,N}) where {T,V,N}
16-
return _broadcasted_mul(a, b)
17-
end
18-
function broadcasted(::DefaultArrayStyle{N}, ::typeof(*), a::AbstractFill{T,N}, b::Zeros{V,N}) where {T,V,N}
19-
return _broadcasted_mul(a, b)
4+
function broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N}
5+
return Fill(op(getindex_value(r)), size(r))
206
end
217

22-
function _broadcasted_mul(a::Zeros{T}, b::AbstractArray{V}) where {T,V}
23-
axes(a) axes(b) && throw(DimensionMismatch("dimensions must match."))
24-
Zeros{promote_type(T,V)}(axes(a))
8+
9+
### Binary broadcasting
10+
11+
function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill)
12+
val = op(getindex_value(a), getindex_value(b))
13+
return Fill(val, broadcast_shape(size(a), size(b)))
2514
end
26-
function broadcasted(::DefaultArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::AbstractArray{V,N}) where {T,V,N}
27-
_broadcasted_mul(a, b)
15+
16+
function _broadcasted_zeros(a, b)
17+
return Zeros{promote_type(eltype(a), eltype(b))}(broadcast_shape(size(a), size(b)))
2818
end
29-
function broadcasted(::DefaultArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::AbstractFill{V,N}) where {T,V,N}
30-
_broadcasted_mul(a, b)
19+
function _broadcasted_ones(a, b)
20+
return Ones{promote_type(eltype(a), eltype(b))}(broadcast_shape(size(a), size(b)))
3121
end
3222

23+
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
24+
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Ones, b::Zeros) = _broadcasted_ones(a, b)
25+
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Ones) = _broadcasted_ones(a, b)
26+
27+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
28+
29+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Ones) = _broadcasted_zeros(a, b)
30+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Fill) = _broadcasted_zeros(a, b)
3331
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractRange)
34-
return Zeros{promote_type(eltype(a), eltype(b))}(broadcast_shape(size(a), size(b)))
32+
return _broadcasted_zeros(a, b)
33+
end
34+
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractArray)
35+
return _broadcasted_zeros(a, b)
3536
end
3637

38+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Zeros) = _broadcasted_zeros(a, b)
39+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Fill, b::Zeros) = _broadcasted_zeros(a, b)
3740
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractRange, b::Zeros)
38-
return Zeros{promote_type(eltype(a), eltype(b))}(broadcast_shape(size(a), size(b)))
41+
return _broadcasted_zeros(a, b)
3942
end
43+
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractArray, b::Zeros)
44+
return _broadcasted_zeros(a, b)
45+
end
46+
47+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Ones) = _broadcasted_ones(a, b)
48+
broadcasted(::DefaultArrayStyle, ::typeof(/), a::Ones, b::Ones) = _broadcasted_ones(a, b)
49+
broadcasted(::DefaultArrayStyle, ::typeof(\), a::Ones, b::Ones) = _broadcasted_ones(a, b)
50+
51+
4052

4153
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractRange)
4254
broadcast_shape(size(a), size(b)) # Check sizes are compatible.
@@ -48,19 +60,7 @@ function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractRange, b::Abst
4860
return broadcasted(*, a, getindex_value(b))
4961
end
5062

51-
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N} = Fill(op(getindex_value(r)), size(r))
5263
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), size(r))
5364
broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), size(r))
5465
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), size(r))
5566
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), size(r))
56-
function broadcasted(::DefaultArrayStyle{N}, op, r1::AbstractFill{T,N}, r2::AbstractFill{V,N}) where {T,V,N}
57-
size(r1)  size(r2) && throw(DimensionMismatch("dimensions must match."))
58-
Fill(op(getindex_value(r1),getindex_value(r2)), size(r1))
59-
end
60-
61-
for op in (:*, :/, :\)
62-
@eval function broadcasted(::DefaultArrayStyle{N}, ::typeof($op), r1::Ones{T,N}, r2::Ones{V,N}) where {T,V,N}
63-
size(r1)  size(r2) && throw(DimensionMismatch("dimensions must match."))
64-
Ones{promote_type(T,V)}(size(r1))
65-
end
66-
end

test/runtests.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,32 @@ end
460460
@test y .+ y Fill(2.0,5,5)
461461
@test y .* y y ./ y y .\ y y
462462

463+
rng = MersenneTwister(123456)
464+
sizes = [(5, 4), (5, 1), (1, 4), (1, 1), (5,)]
465+
for sx in sizes, sy in sizes
466+
x, y = Fill(randn(rng), sx), Fill(randn(rng), sy)
467+
x_one, y_one = Ones(sx), Ones(sy)
468+
x_zero, y_zero = Zeros(sx), Zeros(sy)
469+
x_dense, y_dense = randn(rng, sx), randn(rng, sy)
470+
471+
for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense]
472+
@test x .+ y == collect(x) .+ collect(y)
473+
end
474+
@test x_zero .+ y_zero isa Zeros
475+
@test x_zero .+ y_one isa Ones
476+
@test x_one .+ y_zero isa Ones
477+
478+
for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense]
479+
@test x .* y == collect(x) .* collect(y)
480+
end
481+
for x in [x, x_one, x_zero, x_dense]
482+
@test x .* y_zero isa Zeros
483+
end
484+
for y in [y, y_one, y_zero, y_dense]
485+
@test x_zero .* y isa Zeros
486+
end
487+
end
488+
463489
@test Zeros{Int}(5) .+ Zeros(5) isa Zeros{Float64}
464490

465491
rnge = range(-5.0, step=1.0, length=10)

0 commit comments

Comments
 (0)