Skip to content

Commit 39c4813

Browse files
wsmosesavik-pal
authored andcommitted
Update TracedRArray.jl
1 parent b6bf77c commit 39c4813

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/TracedRArray.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,23 +1282,25 @@ function scan_impl!(
12821282

12831283
dims > ndims(input) && return copyto!(output, input)
12841284

1285-
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1286-
op_in_T === Union{} && (op_in_T = T)
12871285

12881286
if init === nothing
1287+
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1288+
op_in_T === Union{} && (op_in_T = T)
12891289
init = __default_init(T, op)
1290+
if typeof(init) != op_in_T
1291+
op_in_T = typeof(init)
1292+
input = typeof(init).(input)
1293+
end
12901294
else
1291-
initT = __default_init(T, op)
12921295
# TODO: fix this for TPUs
1293-
if initT != init && contains(string(first(Reactant.devices())), "TPU")
1294-
throw(AssertionError("Currently, `init` is not supported on TPUs."))
1296+
if contains(string(first(Reactant.devices())), "TPU")
1297+
initT = __default_init(T, op)
1298+
if initT != init
1299+
throw(AssertionError("Currently, `init` is not supported on TPUs."))
1300+
end
12951301
end
12961302
end
12971303

1298-
if typeof(init) != op_in_T
1299-
op_in_T = typeof(init)
1300-
input = typeof(init).(input)
1301-
end
13021304
init = something(init) # unwrap Some
13031305
init = TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(init)}, init)
13041306

0 commit comments

Comments
 (0)