@@ -1282,23 +1282,25 @@ function scan_impl!(
1282
1282
1283
1283
dims > ndims (input) && return copyto! (output, input)
1284
1284
1285
- op_in_T = Core. Compiler. return_type (op, Tuple{T,T})
1286
- op_in_T === Union{} && (op_in_T = T)
1287
1285
1288
1286
if init === nothing
1287
+ op_in_T = Core. Compiler. return_type (op, Tuple{T,T})
1288
+ op_in_T === Union{} && (op_in_T = T)
1289
1289
init = __default_init (T, op)
1290
+ if typeof (init) != op_in_T
1291
+ op_in_T = typeof (init)
1292
+ input = typeof (init).(input)
1293
+ end
1290
1294
else
1291
- initT = __default_init (T, op)
1292
1295
# TODO : fix this for TPUs
1293
- if initT != init && contains (string (first (Reactant. devices ())), " TPU" )
1294
- throw (AssertionError (" Currently, `init` is not supported on TPUs." ))
1296
+ if contains (string (first (Reactant. devices ())), " TPU" )
1297
+ initT = __default_init (T, op)
1298
+ if initT != init
1299
+ throw (AssertionError (" Currently, `init` is not supported on TPUs." ))
1300
+ end
1295
1301
end
1296
1302
end
1297
1303
1298
- if typeof (init) != op_in_T
1299
- op_in_T = typeof (init)
1300
- input = typeof (init).(input)
1301
- end
1302
1304
init = something (init) # unwrap Some
1303
1305
init = TracedUtils. promote_to (TracedRNumber{unwrapped_eltype (init)}, init)
1304
1306
0 commit comments