@@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt
99
1010 function Enzyme. EnzymeRules. augmented_primal (
1111 config:: Enzyme.EnzymeRules.RevConfigWidth{1} ,
12- func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT} } , prob,
12+ func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{RT } , prob,
1313 sensealg:: Union {
1414 Const{Nothing}, Const{<: DiffEqBase.AbstractSensitivityAlgorithm }},
15- u0, p, args... ; kwargs... ) where {RT}
15+ u0, p, args... ; kwargs... ) where {RT <: Union{Duplicated, DuplicatedNoNeed} }
1616 @inline function copy_or_reuse (val, idx)
1717 if Enzyme. EnzymeRules. overwritten (config)[idx] && ismutable (val)
1818 return deepcopy (val)
@@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt
3131 SciMLBase. EnzymeOriginator (), ntuple (arg_copy, Val (length (args)))... ;
3232 kwargs... )
3333
34- dres = Enzyme. make_zero (res[1 ]):: RT
34+ ResType = typeof (res[1 ])
35+ dres = Enzyme. make_zero (res[1 ]):: ResType
3536 tup = (dres, res[2 ])
36- return Enzyme. EnzymeRules. AugmentedReturn {RT, RT , Any} (res[1 ], dres, tup:: Any )
37+ return Enzyme. EnzymeRules. AugmentedReturn {ResType, ResType , Any} (res[1 ], dres, tup:: Any )
3738 end
3839
3940 function Enzyme. EnzymeRules. reverse (config:: Enzyme.EnzymeRules.RevConfigWidth{1} ,
40- func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT} } , tape, prob,
41+ func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{RT } , tape, prob,
4142 sensealg:: Union {
4243 Const{Nothing}, Const{<: DiffEqBase.AbstractSensitivityAlgorithm }},
43- u0, p, args... ; kwargs... ) where {RT}
44+ u0, p, args... ; kwargs... ) where {RT <: Union{Duplicated, DuplicatedNoNeed} }
4445 dres, clos = tape
45- dres = dres:: RT
4646 dargs = clos (dres)
4747 for (darg, ptr) in zip (dargs, (func, prob, sensealg, u0, p, args... ))
4848 if ptr isa Enzyme. Const
0 commit comments