@@ -27,50 +27,53 @@ function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill)
27
27
return Fill (val, broadcast_shape (axes (a), axes (b)))
28
28
end
29
29
30
- _broadcasted_eltype (a) = eltype (a)
31
- _broadcasted_eltype (a:: Base.Broadcast.Broadcasted ) = Base. Broadcast. combine_eltypes (a. f, a. args)
32
30
33
- _broadcasted_zeros (a, b) = Zeros {promote_type(_broadcasted_eltype(a), _broadcasted_eltype( b))} (broadcast_shape (axes (a), axes (b)))
34
- _broadcasted_ones (a, b) = Ones {promote_type(_broadcasted_eltype(a), _broadcasted_eltype( b))} (broadcast_shape (axes (a), axes (b)))
31
+ _broadcasted_zeros (f, a, b) = Zeros {Base.Broadcast.combine_eltypes(f, (a, b))} (broadcast_shape (axes (a), axes (b)))
32
+ _broadcasted_ones (f, a, b) = Ones {Base.Broadcast.combine_eltypes(f, (a, b))} (broadcast_shape (axes (a), axes (b)))
35
33
36
- broadcasted ( :: DefaultArrayStyle , :: typeof ( + ), a :: Zeros , b :: Zeros ) = _broadcasted_zeros (a, b)
37
- broadcasted ( :: DefaultArrayStyle , :: typeof ( + ), a :: Ones , b:: Zeros ) = _broadcasted_ones ( a, b)
38
- broadcasted ( :: DefaultArrayStyle , :: typeof ( + ), a :: Zeros , b:: Ones ) = _broadcasted_ones (a, b)
34
+ # TODO : remove at next breaking version
35
+ _broadcasted_zeros (a , b) = _broadcasted_zeros ( + , a, b)
36
+ _broadcasted_ones (a , b) = _broadcasted_ones (+ , a, b)
39
37
40
- broadcasted (:: DefaultArrayStyle , :: typeof (- ), a:: Zeros , b:: Zeros ) = _broadcasted_zeros (a, b)
41
- broadcasted (:: DefaultArrayStyle , :: typeof (- ), a:: Ones , b:: Zeros ) = _broadcasted_ones (a, b)
42
- broadcasted (:: DefaultArrayStyle , :: typeof (- ), a:: Ones , b:: Ones ) = _broadcasted_zeros ( a, b)
38
+ broadcasted (:: DefaultArrayStyle , :: typeof (+ ), a:: Zeros , b:: Zeros ) = _broadcasted_zeros (+ , a, b)
39
+ broadcasted (:: DefaultArrayStyle , :: typeof (+ ), a:: Ones , b:: Zeros ) = _broadcasted_ones (+ , a, b)
40
+ broadcasted (:: DefaultArrayStyle , :: typeof (+ ), a:: Zeros , b:: Ones ) = _broadcasted_ones ( + , a, b)
43
41
44
- broadcasted (:: DefaultArrayStyle , :: typeof (* ), a:: Zeros , b:: Zeros ) = _broadcasted_zeros (a, b)
42
+ broadcasted (:: DefaultArrayStyle , :: typeof (- ), a:: Zeros , b:: Zeros ) = _broadcasted_zeros (- , a, b)
43
+ broadcasted (:: DefaultArrayStyle , :: typeof (- ), a:: Ones , b:: Zeros ) = _broadcasted_ones (- , a, b)
44
+ broadcasted (:: DefaultArrayStyle , :: typeof (- ), a:: Ones , b:: Ones ) = _broadcasted_zeros (- , a, b)
45
45
46
+ broadcasted (:: DefaultArrayStyle , :: typeof (* ), a:: Zeros , b:: Zeros ) = _broadcasted_zeros (* , a, b)
47
+
48
+ # In following, need to restrict to <: Number as otherwise we cannot infer zero from type
49
+ # TODO : generalise to things like SVector
46
50
for op in (:* , :/ )
47
51
@eval begin
48
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Ones ) = _broadcasted_zeros (a, b)
49
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Fill ) = _broadcasted_zeros (a, b)
50
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Number ) = _broadcasted_zeros (a, b)
51
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: AbstractRange ) = _broadcasted_zeros (a, b)
52
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: AbstractArray ) = _broadcasted_zeros (a, b)
53
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Base.Broadcast.Broadcasted ) = _broadcasted_zeros (a, b)
54
- broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: Zeros , b:: AbstractRange ) = _broadcasted_zeros (a, b)
52
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Ones ) = _broadcasted_zeros ($ op, a, b)
53
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Fill{<:Number} ) = _broadcasted_zeros ($ op, a, b)
54
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Number ) = _broadcasted_zeros ($ op, a, b)
55
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: AbstractRange ) = _broadcasted_zeros ($ op, a, b)
56
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: AbstractArray{<:Number} ) = _broadcasted_zeros ($ op, a, b)
57
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Base.Broadcast.Broadcasted ) = _broadcasted_zeros ($ op, a, b)
58
+ broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: Zeros , b:: AbstractRange ) = _broadcasted_zeros ($ op, a, b)
55
59
end
56
60
end
57
61
58
62
for op in (:* , :\ )
59
63
@eval begin
60
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Ones , b:: Zeros ) = _broadcasted_zeros (a, b)
61
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Fill , b:: Zeros ) = _broadcasted_zeros (a, b)
62
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Number , b:: Zeros ) = _broadcasted_zeros (a, b)
63
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: AbstractRange , b:: Zeros ) = _broadcasted_zeros (a, b)
64
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: AbstractArray , b:: Zeros ) = _broadcasted_zeros (a, b)
65
- broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Base.Broadcast.Broadcasted , b:: Zeros ) = _broadcasted_zeros (a, b)
66
- broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractRange , b:: Zeros ) = _broadcasted_zeros (a, b)
64
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Ones , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
65
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Fill{<:Number} , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
66
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Number , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
67
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: AbstractRange , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
68
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: AbstractArray{<:Number} , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
69
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Base.Broadcast.Broadcasted , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
70
+ broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractRange , b:: Zeros ) = _broadcasted_zeros ($ op, a, b)
67
71
end
68
72
end
69
73
70
-
71
- broadcasted (:: DefaultArrayStyle , :: typeof (* ), a:: Ones , b:: Ones ) = _broadcasted_ones (a, b)
72
- broadcasted (:: DefaultArrayStyle , :: typeof (/ ), a:: Ones , b:: Ones ) = _broadcasted_ones (a, b)
73
- broadcasted (:: DefaultArrayStyle , :: typeof (\ ), a:: Ones , b:: Ones ) = _broadcasted_ones (a, b)
74
+ for op in (:* , :/ , :\ )
75
+ @eval broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Ones , b:: Ones ) = _broadcasted_ones ($ op, a, b)
76
+ end
74
77
75
78
# special case due to missing converts for ranges
76
79
_range_convert (:: Type{AbstractVector{T}} , a:: AbstractRange{T} ) where T = a
@@ -111,15 +114,19 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange{V}, b
111
114
return _range_convert (AbstractVector{promote_type (T,V)}, a)
112
115
end
113
116
117
+ # Need to prevent array-valued fills from broadcasting over entry
118
+ _broadcast_getindex_value (a:: AbstractFill{<:Number} ) = getindex_value (a)
119
+ _broadcast_getindex_value (a:: AbstractFill ) = Ref (getindex_value (a))
120
+
114
121
115
122
function broadcasted (:: DefaultArrayStyle{1} , :: typeof (* ), a:: AbstractFill , b:: AbstractRange )
116
123
broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $b to a Vector first." ))
117
- return broadcasted (* , getindex_value (a), b)
124
+ return broadcasted (* , _broadcast_getindex_value (a), b)
118
125
end
119
126
120
127
function broadcasted (:: DefaultArrayStyle{1} , :: typeof (* ), a:: AbstractRange , b:: AbstractFill )
121
128
broadcast_shape (axes (a), axes (b)) == axes (a) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $b to a Vector first." ))
122
- return broadcasted (* , a, getindex_value (b))
129
+ return broadcasted (* , a, _broadcast_getindex_value (b))
123
130
end
124
131
125
132
broadcasted (:: DefaultArrayStyle{N} , op, r:: AbstractFill{T,N} , x:: Number ) where {T,N} = Fill (op (getindex_value (r),x), size (r))
0 commit comments