diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index ae5fd322b..3284042e3 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt function Enzyme.EnzymeRules.augmented_primal( config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, prob, sensealg::Union{ Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, - u0, p, args...; kwargs...) where {RT} + u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}} @inline function copy_or_reuse(val, idx) if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) return deepcopy(val) @@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) - dres = Enzyme.make_zero(res[1])::RT + ResType = typeof(res[1]) + dres = Enzyme.make_zero(res[1])::ResType tup = (dres, res[2]) - return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) + return Enzyme.EnzymeRules.AugmentedReturn{ResType, ResType, Any}(res[1], dres, tup::Any) end function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, tape, prob, sensealg::Union{ Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, - u0, p, args...; kwargs...) where {RT} + u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}} dres, clos = tape - dres = dres::RT dargs = clos(dres) for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const