@@ -14,21 +14,41 @@ module EnzymeExt
1414
1515 EnzymeRules. inactive (:: Type{StaticSize} , x... ) = nothing
1616
17+ # https://github.com/EnzymeAD/Enzyme.jl/issues/1516
18+ # On the CPU `autodiff_deferred` can deadlock.
1719 function fwd (ctx, f, args... )
1820 EnzymeCore. autodiff_deferred (Forward, Const (f), Const{Nothing}, Const (ctx), args... )
1921 return nothing
2022 end
2123
2224 function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
2325 TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
24- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26+ forward, _ = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
2527 subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
2628 return nothing
2729 end
2830
2931 function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
3032 TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
31- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
33+ _, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
34+ tp = subtape[__groupindex (ctx)]
35+ reverse (Const (f), Const (ctx), args... , tp)
36+ return nothing
37+ end
38+
39+ function fwd_cpu (ctx, f, args... )
40+ EnzymeCore. autodiff (Forward, Const (f), Const{Nothing}, Const (ctx), args... )
41+ return nothing
42+ end
43+
44+ function aug_fwd_cpu (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
45+ forward, _ = EnzymeCore. autodiff_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
46+ subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
47+ return nothing
48+ end
49+
50+ function rev_cpu (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
51+ _, reverse = EnzymeCore. autodiff_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
3252 tp = subtape[__groupindex (ctx)]
3353 reverse (Const (f), Const (ctx), args... , tp)
3454 return nothing
@@ -42,6 +62,15 @@ module EnzymeExt
4262 fwd_kernel (f, args... ; ndrange, workgroupsize)
4363 end
4464
65+ function EnzymeRules. forward (func:: Const{<:Kernel{CPU}} , :: Type{Const{Nothing}} , args... ; ndrange= nothing , workgroupsize= nothing )
66+ kernel = func. val
67+ f = kernel. f
68+ fwd_kernel = similar (kernel, fwd_cpu)
69+
70+ fwd_kernel (f, args... ; ndrange, workgroupsize)
71+ end
72+
73+
4574 @inline function make_active_byref (f:: F , :: Val{ActiveTys} ) where {F, ActiveTys}
4675 if ! any (ActiveTys)
4776 return f
@@ -103,7 +132,7 @@ module EnzymeExt
103132
104133 subtape = Array {TapeType} (undef, size (blocks (iterspace)))
105134
106- aug_kernel = similar (kernel, aug_fwd )
135+ aug_kernel = similar (kernel, aug_fwd_cpu )
107136
108137 aug_kernel (f, ModifiedBetween, subtape, args2... ; ndrange, workgroupsize)
109138
@@ -115,7 +144,7 @@ module EnzymeExt
115144 return res
116145 end
117146
118- function EnzymeRules. reverse (config:: Config , func:: Const{<:Kernel} , :: Type{<:EnzymeCore.Annotation} , tape, args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
147+ function EnzymeRules. reverse (config:: Config , func:: Const{<:Kernel{CPU} } , :: Type{<:EnzymeCore.Annotation} , tape, args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
119148 subtape, arg_refs = tape
120149
121150 args2 = ntuple (Val (N)) do i
@@ -138,7 +167,7 @@ module EnzymeExt
138167
139168 ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
140169
141- rev_kernel = similar (func. val, rev )
170+ rev_kernel = similar (func. val, rev_cpu )
142171 rev_kernel (f, ModifiedBetween, subtape, args2... ; ndrange, workgroupsize)
143172 return ntuple (Val (N)) do i
144173 Base. @_inline_meta
0 commit comments