Skip to content

Commit 05a3c05

Browse files
authored
Merge pull request #514 from jlk9/main
Valid index check for gpu in EnzymeExt
2 parents 91ada95 + 80b2996 commit 05a3c05

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

ext/EnzymeExt.jl

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ import KernelAbstractions:
2525
__index_Global_Linear,
2626
__groupsize,
2727
__groupindex,
28+
__validindex,
2829
Backend,
2930
synchronize
3031

@@ -219,8 +220,10 @@ function gpu_aug_fwd(
219220

220221
# On the GPU: F is a per thread function
221222
# On the GPU: subtape::Vector
222-
I = __index_Global_Linear(ctx)
223-
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
223+
if __validindex(ctx)
224+
I = __index_Global_Linear(ctx)
225+
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
226+
end
224227
return nothing
225228
end
226229

@@ -241,9 +244,11 @@ function gpu_rev(
241244
Const{Core.Typeof(ctx)},
242245
map(Core.Typeof, args)...,
243246
)
244-
I = __index_Global_Linear(ctx)
245-
tp = subtape[I]
246-
reverse(Const(f), Const(ctx), args..., tp)
247+
if __validindex(ctx)
248+
I = __index_Global_Linear(ctx)
249+
tp = subtape[I]
250+
reverse(Const(f), Const(ctx), args..., tp)
251+
end
247252
return nothing
248253
end
249254

0 commit comments

Comments
 (0)