@@ -1281,23 +1281,25 @@ function scan_impl!(
1281
1281
1282
1282
dims > ndims (input) && return copyto! (output, input)
1283
1283
1284
- op_in_T = Core. Compiler. return_type (op, Tuple{T,T})
1285
- op_in_T === Union{} && (op_in_T = T)
1286
1284
1287
1285
if init === nothing
1286
+ op_in_T = Core. Compiler. return_type (op, Tuple{T,T})
1287
+ op_in_T === Union{} && (op_in_T = T)
1288
1288
init = __default_init (T, op)
1289
+ if typeof (init) != op_in_T
1290
+ op_in_T = typeof (init)
1291
+ input = typeof (init).(input)
1292
+ end
1289
1293
else
1290
- initT = __default_init (T, op)
1291
1294
# TODO : fix this for TPUs
1292
- if initT != init && contains (string (first (Reactant. devices ())), " TPU" )
1293
- throw (AssertionError (" Currently, `init` is not supported on TPUs." ))
1295
+ if contains (string (first (Reactant. devices ())), " TPU" )
1296
+ initT = __default_init (T, op)
1297
+ if initT != init
1298
+ throw (AssertionError (" Currently, `init` is not supported on TPUs." ))
1299
+ end
1294
1300
end
1295
1301
end
1296
1302
1297
- if typeof (init) != op_in_T
1298
- op_in_T = typeof (init)
1299
- input = typeof (init).(input)
1300
- end
1301
1303
init = something (init) # unwrap Some
1302
1304
init = TracedUtils. promote_to (TracedRNumber{unwrapped_eltype (init)}, init)
1303
1305
0 commit comments