Skip to content

Commit 4f47a31

Browse files
committed
enable fully explicit AD method
1 parent 427b1e8 commit 4f47a31

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

src/ParallelKernel/EnzymeExt/autodiff_gpu.jl

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,35 @@ function promote_to_const(args::Vararg{Any,N}) where N
2525
end
2626
end
2727

28-
function ParallelStencil.ParallelKernel.AD.autodiff_deferred!(mode, f, args::Vararg{Any,N}) where N # NOTE: minimal specialization is required to avoid overwriting the default method
29-
f = promote_to_const(f)[1]
30-
args = promote_to_const(args...)
31-
Enzyme.autodiff_deferred(mode, f, Enzyme.Const, args...)
32-
return
33-
end
34-
35-
function ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(mode, f, args::Vararg{Any,N}) where N # NOTE: minimal specialization is required to avoid overwriting the default method
36-
f = promote_to_const(f)[1]
37-
args = promote_to_const(args...)
38-
Enzyme.autodiff_deferred_thunk(mode, f, Enzyme.Const, args...)
39-
return
40-
end
28+
29+
function ParallelStencil.ParallelKernel.AD.autodiff_deferred!(mode, f, ::Type{T}, args::Vararg{Any,N}) where {T<:Enzyme.Annotation, N} # NOTE: minimal specialization is required to avoid overwriting the default method
30+
f = promote_to_const(f)[1]
31+
args = promote_to_const(args...)
32+
Enzyme.autodiff_deferred(mode, f, T, args...)
33+
return
34+
end
35+
36+
function ParallelStencil.ParallelKernel.AD.autodiff_deferred!(mode, f, args::Vararg{Any,N}) where N # NOTE: minimal specialization is required to avoid overwriting the default method
37+
f = promote_to_const(f)[1]
38+
args = promote_to_const(args...)
39+
Enzyme.autodiff_deferred(mode, f, Enzyme.Const, args...)
40+
return
41+
end
42+
43+
44+
function ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(mode, f, ::Type{T}, args::Vararg{Any,N}) where {T<:Enzyme.Annotation, N} # NOTE: minimal specialization is required to avoid overwriting the default method
45+
f = promote_to_const(f)[1]
46+
args = promote_to_const(args...)
47+
Enzyme.autodiff_deferred_thunk(mode, f, T, args...)
48+
return
49+
end
50+
51+
function ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(mode, f, args::Vararg{Any,N}) where N # NOTE: minimal specialization is required to avoid overwriting the default method
52+
f = promote_to_const(f)[1]
53+
args = promote_to_const(args...)
54+
Enzyme.autodiff_deferred_thunk(mode, f, Enzyme.Const, args...)
55+
return
56+
end
4157

4258

4359
## FUNCTIONS TO CHECK EXTENSIONS SUPPORT

0 commit comments

Comments
 (0)