Skip to content

Commit b07e2f5

Browse files
wsmosesavik-pal
authored andcommitted
Update TracedRArray.jl
1 parent 1b3a626 commit b07e2f5

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/TracedRArray.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,23 +1281,25 @@ function scan_impl!(
12811281

12821282
dims > ndims(input) && return copyto!(output, input)
12831283

1284-
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1285-
op_in_T === Union{} && (op_in_T = T)
12861284

12871285
if init === nothing
1286+
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1287+
op_in_T === Union{} && (op_in_T = T)
12881288
init = __default_init(T, op)
1289+
if typeof(init) != op_in_T
1290+
op_in_T = typeof(init)
1291+
input = typeof(init).(input)
1292+
end
12891293
else
1290-
initT = __default_init(T, op)
12911294
# TODO: fix this for TPUs
1292-
if initT != init && contains(string(first(Reactant.devices())), "TPU")
1293-
throw(AssertionError("Currently, `init` is not supported on TPUs."))
1295+
if contains(string(first(Reactant.devices())), "TPU")
1296+
initT = __default_init(T, op)
1297+
if initT != init
1298+
throw(AssertionError("Currently, `init` is not supported on TPUs."))
1299+
end
12941300
end
12951301
end
12961302

1297-
if typeof(init) != op_in_T
1298-
op_in_T = typeof(init)
1299-
input = typeof(init).(input)
1300-
end
13011303
init = something(init) # unwrap Some
13021304
init = TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(init)}, init)
13031305

0 commit comments

Comments
 (0)