Skip to content

Commit 3de9f21

Browse files
authored
Add conditional check to index of context in gpu augmented forward and reverse in EnzymeExt
1 parent 91ada95 commit 3de9f21

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

ext/EnzymeExt.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,10 @@ function gpu_aug_fwd(
219219

220220
# On the GPU: F is a per thread function
221221
# On the GPU: subtape::Vector
222-
I = __index_Global_Linear(ctx)
223-
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
222+
if __validindex(ctx)
223+
I = __index_Global_Linear(ctx)
224+
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
225+
end
224226
return nothing
225227
end
226228

@@ -241,9 +243,11 @@ function gpu_rev(
241243
Const{Core.Typeof(ctx)},
242244
map(Core.Typeof, args)...,
243245
)
244-
I = __index_Global_Linear(ctx)
245-
tp = subtape[I]
246-
reverse(Const(f), Const(ctx), args..., tp)
246+
if __validindex(ctx)
247+
I = __index_Global_Linear(ctx)
248+
tp = subtape[I]
249+
reverse(Const(f), Const(ctx), args..., tp)
250+
end
247251
return nothing
248252
end
249253

0 commit comments

Comments
 (0)