diff --git a/Project.toml b/Project.toml index 4ce983bf5..42d2d3f9d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.190.4" +version = "6.190.5" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index 3284042e3..ae5fd322b 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt function Enzyme.EnzymeRules.augmented_primal( config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, prob, + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{ Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, - u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}} + u0, p, args...; kwargs...) where {RT} @inline function copy_or_reuse(val, idx) if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) return deepcopy(val) @@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) - ResType = typeof(res[1]) - dres = Enzyme.make_zero(res[1])::ResType + dres = Enzyme.make_zero(res[1])::RT tup = (dres, res[2]) - return Enzyme.EnzymeRules.AugmentedReturn{ResType, ResType, Any}(res[1], dres, tup::Any) + return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) end function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, tape, prob, + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, sensealg::Union{ Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, - u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}} + u0, p, args...; kwargs...) where {RT} dres, clos = tape + dres = dres::RT dargs = clos(dres) for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const