Skip to content

Commit 7a07049

Browse files
authored
Fix broadcasting when FIll is array-valued (#115)
* Fix broadcasting when FIll is array-valued * Update runtests.jl
1 parent 224836b commit 7a07049

File tree

3 files changed

+50
-32
lines changed

3 files changed

+50
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.9.3"
3+
version = "0.9.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/fillbroadcast.jl

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,50 +27,53 @@ function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill)
2727
return Fill(val, broadcast_shape(axes(a), axes(b)))
2828
end
2929

30-
_broadcasted_eltype(a) = eltype(a)
31-
_broadcasted_eltype(a::Base.Broadcast.Broadcasted) = Base.Broadcast.combine_eltypes(a.f, a.args)
3230

33-
_broadcasted_zeros(a, b) = Zeros{promote_type(_broadcasted_eltype(a), _broadcasted_eltype(b))}(broadcast_shape(axes(a), axes(b)))
34-
_broadcasted_ones(a, b) = Ones{promote_type(_broadcasted_eltype(a), _broadcasted_eltype(b))}(broadcast_shape(axes(a), axes(b)))
31+
_broadcasted_zeros(f, a, b) = Zeros{Base.Broadcast.combine_eltypes(f, (a, b))}(broadcast_shape(axes(a), axes(b)))
32+
_broadcasted_ones(f, a, b) = Ones{Base.Broadcast.combine_eltypes(f, (a, b))}(broadcast_shape(axes(a), axes(b)))
3533

36-
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
37-
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Ones, b::Zeros) = _broadcasted_ones(a, b)
38-
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Ones) = _broadcasted_ones(a, b)
34+
# TODO: remove at next breaking version
35+
_broadcasted_zeros(a, b) = _broadcasted_zeros(+, a, b)
36+
_broadcasted_ones(a, b) = _broadcasted_ones(+, a, b)
3937

40-
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
41-
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Ones, b::Zeros) = _broadcasted_ones(a, b)
42-
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Ones, b::Ones) = _broadcasted_zeros(a, b)
38+
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Zeros) = _broadcasted_zeros(+, a, b)
39+
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Ones, b::Zeros) = _broadcasted_ones(+, a, b)
40+
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Ones) = _broadcasted_ones(+, a, b)
4341

44-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
42+
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Zeros, b::Zeros) = _broadcasted_zeros(-, a, b)
43+
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Ones, b::Zeros) = _broadcasted_ones(-, a, b)
44+
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Ones, b::Ones) = _broadcasted_zeros(-, a, b)
4545

46+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Zeros) = _broadcasted_zeros(*, a, b)
47+
48+
# In following, need to restrict to <: Number as otherwise we cannot infer zero from type
49+
# TODO: generalise to things like SVector
4650
for op in (:*, :/)
4751
@eval begin
48-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Ones) = _broadcasted_zeros(a, b)
49-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Fill) = _broadcasted_zeros(a, b)
50-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Number) = _broadcasted_zeros(a, b)
51-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros(a, b)
52-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractArray) = _broadcasted_zeros(a, b)
53-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros(a, b)
54-
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros(a, b)
52+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Ones) = _broadcasted_zeros($op, a, b)
53+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Fill{<:Number}) = _broadcasted_zeros($op, a, b)
54+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Number) = _broadcasted_zeros($op, a, b)
55+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros($op, a, b)
56+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b)
57+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b)
58+
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros($op, a, b)
5559
end
5660
end
5761

5862
for op in (:*, :\)
5963
@eval begin
60-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Ones, b::Zeros) = _broadcasted_zeros(a, b)
61-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Fill, b::Zeros) = _broadcasted_zeros(a, b)
62-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::Zeros) = _broadcasted_zeros(a, b)
63-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros(a, b)
64-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray, b::Zeros) = _broadcasted_zeros(a, b)
65-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::Zeros) = _broadcasted_zeros(a, b)
66-
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros(a, b)
64+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Ones, b::Zeros) = _broadcasted_zeros($op, a, b)
65+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Fill{<:Number}, b::Zeros) = _broadcasted_zeros($op, a, b)
66+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::Zeros) = _broadcasted_zeros($op, a, b)
67+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros($op, a, b)
68+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray{<:Number}, b::Zeros) = _broadcasted_zeros($op, a, b)
69+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::Zeros) = _broadcasted_zeros($op, a, b)
70+
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros($op, a, b)
6771
end
6872
end
6973

70-
71-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Ones) = _broadcasted_ones(a, b)
72-
broadcasted(::DefaultArrayStyle, ::typeof(/), a::Ones, b::Ones) = _broadcasted_ones(a, b)
73-
broadcasted(::DefaultArrayStyle, ::typeof(\), a::Ones, b::Ones) = _broadcasted_ones(a, b)
74+
for op in (:*, :/, :\)
75+
@eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::Ones, b::Ones) = _broadcasted_ones($op, a, b)
76+
end
7477

7578
# special case due to missing converts for ranges
7679
_range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a
@@ -111,15 +114,19 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange{V}, b
111114
return _range_convert(AbstractVector{promote_type(T,V)}, a)
112115
end
113116

117+
# Need to prevent array-valued fills from broadcasting over entry
118+
_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a)
119+
_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a))
120+
114121

115122
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
116123
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
117-
return broadcasted(*, getindex_value(a), b)
124+
return broadcasted(*, _broadcast_getindex_value(a), b)
118125
end
119126

120127
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
121128
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
122-
return broadcasted(*, a, getindex_value(b))
129+
return broadcasted(*, a, _broadcast_getindex_value(b))
123130
end
124131

125132
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), size(r))

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,10 +576,21 @@ end
576576
@test_throws DimensionMismatch broadcast(*, Fill(1,3), 1:6)
577577
@test_throws DimensionMismatch broadcast(*, 1:6, Fill(1,3))
578578

579+
@testset "Number" begin
580+
@test broadcast(*, Zeros(5), 2) broadcast(*, 2, Zeros(5)) Zeros(5)
581+
end
582+
579583
@testset "Nested" begin
580584
@test randn(5) .\ rand(5) .* Zeros(5) Zeros(5)
581585
@test broadcast(*, Zeros(5), Base.Broadcast.broadcasted(\, randn(5), rand(5))) Zeros(5)
582586
end
587+
588+
@testset "array-valued" begin
589+
@test broadcast(*, Fill([1,2],3), 1:3) == broadcast(*, 1:3, Fill([1,2],3)) == broadcast(*, 1:3, fill([1,2],3))
590+
@test broadcast(*, Fill([1,2],3), Zeros(3)) == broadcast(*, Zeros(3), Fill([1,2],3)) == broadcast(*, zeros(3), fill([1,2],3))
591+
@test broadcast(*, Fill([1,2],3), Zeros(3)) isa Fill{Vector{Float64}}
592+
@test broadcast(*, [[1,2], [3,4,5]], Zeros(2)) == broadcast(*, Zeros(2), [[1,2], [3,4,5]]) == broadcast(*, zeros(2), [[1,2], [3,4,5]])
593+
end
583594
end
584595

585596
@testset "support Ref" begin

0 commit comments

Comments
 (0)