@@ -195,20 +195,18 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::O
195
195
return _range_convert (AbstractVector{TT}, a)
196
196
end
197
197
198
- # To fix AD issues with `broadcast(T, x)`
199
- # Avoids type inference issues with x -> T(x)
200
- struct Constructor{T} end
201
-
202
- function (:: Constructor{T} )(x) where {T}
203
- return T (x)
204
- end
205
-
206
198
for op in (:+ , :- )
207
199
@eval begin
208
200
function broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractVector , b:: ZerosVector )
209
201
broadcast_shape (axes (a), axes (b)) == axes (a) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $b to a Vector first." ))
210
202
TT = typeof ($ op (zero (eltype (a)), zero (eltype (b))))
211
- eltype (a) === TT ? a : broadcasted (Constructor {TT} (), a)
203
+ # Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)`
204
+ eltype (a) === TT ? a : broadcasted (TT ∘ (+ ), a)
205
+ end
206
+ function broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: ZerosVector , b:: AbstractVector )
207
+ broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $a to a Vector first." ))
208
+ TT = typeof ($ op (zero (eltype (a)), zero (eltype (b))))
209
+ $ op === (+ ) && eltype (b) === TT ? b : broadcasted (TT ∘ ($ op), b)
212
210
end
213
211
214
212
broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractFillVector , b:: ZerosVector ) =
@@ -219,18 +217,6 @@ for op in (:+, :-)
219
217
end
220
218
end
221
219
222
- function broadcasted (:: DefaultArrayStyle{1} , :: typeof (+ ), a:: ZerosVector , b:: AbstractVector )
223
- broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $a to a Vector first." ))
224
- TT = typeof (zero (eltype (a)) + zero (eltype (b)))
225
- eltype (b) === TT ? b : broadcasted (Constructor {TT} (), b)
226
- end
227
-
228
- function broadcasted (:: DefaultArrayStyle{1} , :: typeof (- ), a:: ZerosVector , b:: AbstractVector )
229
- broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $a to a Vector first." ))
230
- TT = typeof (zero (eltype (a)) - zero (eltype (b)))
231
- broadcasted (TT ∘ (- ), b)
232
- end
233
-
234
220
# Need to prevent array-valued fills from broadcasting over entry
235
221
_broadcast_getindex_value (a:: AbstractFill{<:Number} ) = getindex_value (a)
236
222
_broadcast_getindex_value (a:: AbstractFill ) = Ref (getindex_value (a))
0 commit comments