@@ -2,14 +2,13 @@ module DiffEqBaseEnzymeExt
2
2
3
3
using DiffEqBase
4
4
import DiffEqBase: value
5
- isdefined (Base, :get_extension ) ? ( import Enzyme) : ( import .. Enzyme)
6
-
5
+ using Enzyme
6
+ import Enzyme : Const
7
7
using ChainRulesCore
8
- using EnzymeCore
9
8
10
- function EnzymeCore . EnzymeRules. augmented_primal (config:: EnzymeCore .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT}} , prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
9
+ function Enzyme . EnzymeRules. augmented_primal (config:: Enzyme .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT}} , prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
11
10
@inline function copy_or_reuse (val, idx)
12
- if EnzymeCore . EnzymeRules. overwritten (config)[idx] && ismutable (val)
11
+ if Enzyme . EnzymeRules. overwritten (config)[idx] && ismutable (val)
13
12
return deepcopy (val)
14
13
else
15
14
return val
@@ -28,15 +27,15 @@ function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.
28
27
v.= 0
29
28
end
30
29
tup = (dres, res[2 ])
31
- return EnzymeCore . EnzymeRules. AugmentedReturn {RT, RT, Any} (res[1 ], dres, tup:: Any )
30
+ return Enzyme . EnzymeRules. AugmentedReturn {RT, RT, Any} (res[1 ], dres, tup:: Any )
32
31
end
33
32
34
- function EnzymeCore . EnzymeRules . reverse (config:: EnzymeCore .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{<:Duplicated{RT}} , tape, prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
33
+ function Enzyme . reverse (config:: Enzyme .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{<:Duplicated{RT}} , tape, prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
35
34
dres, clos = tape
36
35
dres = dres:: RT
37
36
dargs = clos (dres)
38
37
for (darg, ptr) in zip (dargs, (func, prob, sensealg, u0, p, args... ))
39
- if ptr isa EnzymeCore . Const
38
+ if ptr isa Enzyme . Const
40
39
continue
41
40
end
42
41
if darg == ChainRulesCore. NoTangent ()
0 commit comments