@@ -27,12 +27,11 @@ function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill)
27
27
return Fill (val, broadcast_shape (axes (a), axes (b)))
28
28
end
29
29
30
- function _broadcasted_zeros (a, b)
31
- return Zeros {promote_type(eltype(a), eltype(b))} (broadcast_shape (axes (a), axes (b)))
32
- end
33
- function _broadcasted_ones (a, b)
34
- return Ones {promote_type(eltype(a), eltype(b))} (broadcast_shape (axes (a), axes (b)))
35
- end
30
+ _broadcasted_eltype (a) = eltype (a)
31
+ _broadcasted_eltype (a:: Base.Broadcast.Broadcasted ) = Base. Broadcast. combine_eltypes (a. f, a. args)
32
+
33
+ _broadcasted_zeros (a, b) = Zeros {promote_type(_broadcasted_eltype(a), _broadcasted_eltype(b))} (broadcast_shape (axes (a), axes (b)))
34
+ _broadcasted_ones (a, b) = Ones {promote_type(_broadcasted_eltype(a), _broadcasted_eltype(b))} (broadcast_shape (axes (a), axes (b)))
36
35
37
36
broadcasted (:: DefaultArrayStyle , :: typeof (+ ), a:: Zeros , b:: Zeros ) = _broadcasted_zeros (a, b)
38
37
broadcasted (:: DefaultArrayStyle , :: typeof (+ ), a:: Ones , b:: Zeros ) = _broadcasted_ones (a, b)
@@ -51,6 +50,7 @@ for op in (:*, :/)
51
50
broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Number ) = _broadcasted_zeros (a, b)
52
51
broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: AbstractRange ) = _broadcasted_zeros (a, b)
53
52
broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: AbstractArray ) = _broadcasted_zeros (a, b)
53
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Zeros , b:: Base.Broadcast.Broadcasted ) = _broadcasted_zeros (a, b)
54
54
broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: Zeros , b:: AbstractRange ) = _broadcasted_zeros (a, b)
55
55
end
56
56
end
@@ -62,6 +62,7 @@ for op in (:*, :\)
62
62
broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Number , b:: Zeros ) = _broadcasted_zeros (a, b)
63
63
broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: AbstractRange , b:: Zeros ) = _broadcasted_zeros (a, b)
64
64
broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: AbstractArray , b:: Zeros ) = _broadcasted_zeros (a, b)
65
+ broadcasted (:: DefaultArrayStyle , :: typeof ($ op), a:: Base.Broadcast.Broadcasted , b:: Zeros ) = _broadcasted_zeros (a, b)
65
66
broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractRange , b:: Zeros ) = _broadcasted_zeros (a, b)
66
67
end
67
68
end
@@ -76,24 +77,48 @@ _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a
76
77
_range_convert (:: Type{AbstractVector{T}} , a:: AbstractUnitRange ) where T = convert (T,first (a)): convert (T,last (a))
77
78
_range_convert (:: Type{AbstractVector{T}} , a:: AbstractRange ) where T = convert (T,first (a)): step (a): convert (T,last (a))
78
79
80
+
81
+ # TODO : replacing with the following will support more general broadcasting.
82
+ # function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
83
+ # broadcast_shape(axes(a), axes(b)) # check axes
84
+ # r1 = b[1] * getindex_value(a)
85
+ # T = typeof(r1)
86
+ # if length(b) == 1 # Need a fill, but for type stability use StepRangeLen
87
+ # StepRangeLen{T}(r1, zero(T), length(a))
88
+ # else
89
+ # StepRangeLen{T}(r1, convert(T, getindex_value(a) * step(b)), length(b))
90
+ # end
91
+ # end
92
+
93
+ # function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
94
+ # broadcast_shape(axes(a), axes(b)) # check axes
95
+ # r1 = a[1] * getindex_value(b)
96
+ # T = typeof(r1)
97
+ # if length(a) == 1 # Need a fill, but for type stability use StepRangeLen
98
+ # StepRangeLen{T}(r1, zero(T), length(b))
99
+ # else
100
+ # StepRangeLen{T}(r1, convert(T, step(a) * getindex_value(b)), length(a))
101
+ # end
102
+ # end
103
+
79
104
function broadcasted (:: DefaultArrayStyle{1} , :: typeof (* ), a:: Ones{T} , b:: AbstractRange{V} ) where {T,V}
80
- broadcast_shape (axes (a), axes (b)) # Check sizes are compatible.
105
+ broadcast_shape (axes (a), axes (b)) == axes (b) || throw ( ArgumentError ( " Cannot broadcast $a and $b . Convert $b to a Vector first. " ))
81
106
return _range_convert (AbstractVector{promote_type (T,V)}, b)
82
107
end
83
108
84
109
function broadcasted (:: DefaultArrayStyle{1} , :: typeof (* ), a:: AbstractRange{V} , b:: Ones{T} ) where {T,V}
85
- broadcast_shape (axes (a), axes (b)) # Check sizes are compatible.
110
+ broadcast_shape (axes (a), axes (b)) == axes (a) || throw ( ArgumentError ( " Cannot broadcast $a and $b . Convert $b to a Vector first. " ))
86
111
return _range_convert (AbstractVector{promote_type (T,V)}, a)
87
112
end
88
113
89
114
90
115
function broadcasted (:: DefaultArrayStyle{1} , :: typeof (* ), a:: AbstractFill , b:: AbstractRange )
91
- broadcast_shape (axes (a), axes (b)) # Check sizes are compatible.
116
+ broadcast_shape (axes (a), axes (b)) == axes (b) || throw ( ArgumentError ( " Cannot broadcast $a and $b . Convert $b to a Vector first. " ))
92
117
return broadcasted (* , getindex_value (a), b)
93
118
end
94
119
95
120
function broadcasted (:: DefaultArrayStyle{1} , :: typeof (* ), a:: AbstractRange , b:: AbstractFill )
96
- broadcast_shape (axes (a), axes (b)) # Check sizes are compatible.
121
+ broadcast_shape (axes (a), axes (b)) == axes (a) || throw ( ArgumentError ( " Cannot broadcast $a and $b . Convert $b to a Vector first. " ))
97
122
return broadcasted (* , a, getindex_value (b))
98
123
end
99
124
0 commit comments