Skip to content

Commit 224836b

Browse files
authored
Nested special broadcast, throw error when special broadcast fails (#114)
* Use convert for ones broadcasted with matrices * Flag dodgy broadcast * Zeros broadcasting for more complicated broadcast * Update fillbroadcast.jl * increase coverage * Update runtests.jl
1 parent 41745e8 commit 224836b

File tree

3 files changed

+57
-18
lines changed

3 files changed

+57
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.9.2"
3+
version = "0.9.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/fillbroadcast.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@ function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill)
2727
return Fill(val, broadcast_shape(axes(a), axes(b)))
2828
end
2929

30-
function _broadcasted_zeros(a, b)
31-
return Zeros{promote_type(eltype(a), eltype(b))}(broadcast_shape(axes(a), axes(b)))
32-
end
33-
function _broadcasted_ones(a, b)
34-
return Ones{promote_type(eltype(a), eltype(b))}(broadcast_shape(axes(a), axes(b)))
35-
end
30+
_broadcasted_eltype(a) = eltype(a)
31+
_broadcasted_eltype(a::Base.Broadcast.Broadcasted) = Base.Broadcast.combine_eltypes(a.f, a.args)
32+
33+
_broadcasted_zeros(a, b) = Zeros{promote_type(_broadcasted_eltype(a), _broadcasted_eltype(b))}(broadcast_shape(axes(a), axes(b)))
34+
_broadcasted_ones(a, b) = Ones{promote_type(_broadcasted_eltype(a), _broadcasted_eltype(b))}(broadcast_shape(axes(a), axes(b)))
3635

3736
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
3837
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Ones, b::Zeros) = _broadcasted_ones(a, b)
@@ -51,6 +50,7 @@ for op in (:*, :/)
5150
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Number) = _broadcasted_zeros(a, b)
5251
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros(a, b)
5352
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractArray) = _broadcasted_zeros(a, b)
53+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros(a, b)
5454
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros(a, b)
5555
end
5656
end
@@ -62,6 +62,7 @@ for op in (:*, :\)
6262
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::Zeros) = _broadcasted_zeros(a, b)
6363
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros(a, b)
6464
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray, b::Zeros) = _broadcasted_zeros(a, b)
65+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::Zeros) = _broadcasted_zeros(a, b)
6566
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros(a, b)
6667
end
6768
end
@@ -76,24 +77,48 @@ _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a
7677
_range_convert(::Type{AbstractVector{T}}, a::AbstractUnitRange) where T = convert(T,first(a)):convert(T,last(a))
7778
_range_convert(::Type{AbstractVector{T}}, a::AbstractRange) where T = convert(T,first(a)):step(a):convert(T,last(a))
7879

80+
81+
# TODO: replacing with the following will support more general broadcasting.
82+
# function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
83+
# broadcast_shape(axes(a), axes(b)) # check axes
84+
# r1 = b[1] * getindex_value(a)
85+
# T = typeof(r1)
86+
# if length(b) == 1 # Need a fill, but for type stability use StepRangeLen
87+
# StepRangeLen{T}(r1, zero(T), length(a))
88+
# else
89+
# StepRangeLen{T}(r1, convert(T, getindex_value(a) * step(b)), length(b))
90+
# end
91+
# end
92+
93+
# function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
94+
# broadcast_shape(axes(a), axes(b)) # check axes
95+
# r1 = a[1] * getindex_value(b)
96+
# T = typeof(r1)
97+
# if length(a) == 1 # Need a fill, but for type stability use StepRangeLen
98+
# StepRangeLen{T}(r1, zero(T), length(b))
99+
# else
100+
# StepRangeLen{T}(r1, convert(T, step(a) * getindex_value(b)), length(a))
101+
# end
102+
# end
103+
79104
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::Ones{T}, b::AbstractRange{V}) where {T,V}
80-
broadcast_shape(axes(a), axes(b)) # Check sizes are compatible.
105+
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
81106
return _range_convert(AbstractVector{promote_type(T,V)}, b)
82107
end
83108

84109
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange{V}, b::Ones{T}) where {T,V}
85-
broadcast_shape(axes(a), axes(b)) # Check sizes are compatible.
110+
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
86111
return _range_convert(AbstractVector{promote_type(T,V)}, a)
87112
end
88113

89114

90115
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
91-
broadcast_shape(axes(a), axes(b)) # Check sizes are compatible.
116+
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
92117
return broadcasted(*, getindex_value(a), b)
93118
end
94119

95120
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
96-
broadcast_shape(axes(a), axes(b)) # Check sizes are compatible.
121+
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
97122
return broadcasted(*, a, getindex_value(b))
98123
end
99124

test/runtests.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -537,13 +537,22 @@ end
537537
@test imag(Ones{ComplexF64}(10)) isa Zeros{Float64}
538538
@test imag(Ones{ComplexF64}(10,10)) isa Zeros{Float64}
539539

540-
rnge = range(-5.0, step=1.0, length=10)
541-
@test broadcast(*, Fill(5.0, 10), rnge) == broadcast(*, 5.0, rnge)
542-
@test broadcast(*, Zeros(10, 10), rnge) == zeros(10, 10)
543-
@test broadcast(*, rnge, Zeros(10, 10)) == zeros(10, 10)
544-
@test_throws DimensionMismatch broadcast(*, Fill(5.0, 11), rnge)
545-
@test broadcast(*, rnge, Fill(5.0, 10)) == broadcast(*, rnge, 5.0)
546-
@test_throws DimensionMismatch broadcast(*, rnge, Fill(5.0, 11))
540+
@testset "range broadcast" begin
541+
rnge = range(-5.0, step=1.0, length=10)
542+
@test broadcast(*, Fill(5.0, 10), rnge) == broadcast(*, 5.0, rnge)
543+
@test broadcast(*, Zeros(10, 10), rnge) == zeros(10, 10)
544+
@test broadcast(*, rnge, Zeros(10, 10)) == zeros(10, 10)
545+
@test broadcast(*, Ones{Int}(10), rnge) rnge
546+
@test broadcast(*, rnge, Ones{Int}(10)) rnge
547+
@test_throws DimensionMismatch broadcast(*, Fill(5.0, 11), rnge)
548+
@test broadcast(*, rnge, Fill(5.0, 10)) == broadcast(*, rnge, 5.0)
549+
@test_throws DimensionMismatch broadcast(*, rnge, Fill(5.0, 11))
550+
551+
# following should pass using alternative implementation in code
552+
deg = 5:5
553+
@test_throws ArgumentError @inferred(broadcast(*, Fill(5.0, 10), deg)) == broadcast(*, fill(5.0,10), deg)
554+
@test_throws ArgumentError @inferred(broadcast(*, deg, Fill(5.0, 10))) == broadcast(*, deg, fill(5.0,10))
555+
end
547556

548557
@testset "Special Zeros/Ones" begin
549558
@test broadcast(+,Zeros(5)) broadcast(-,Zeros(5)) Zeros(5)
@@ -566,6 +575,11 @@ end
566575
@test_throws DimensionMismatch broadcast(*, 1:6, Ones(3))
567576
@test_throws DimensionMismatch broadcast(*, Fill(1,3), 1:6)
568577
@test_throws DimensionMismatch broadcast(*, 1:6, Fill(1,3))
578+
579+
@testset "Nested" begin
580+
@test randn(5) .\ rand(5) .* Zeros(5) Zeros(5)
581+
@test broadcast(*, Zeros(5), Base.Broadcast.broadcasted(\, randn(5), rand(5))) Zeros(5)
582+
end
569583
end
570584

571585
@testset "support Ref" begin

0 commit comments

Comments
 (0)