diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index d785e4ef4..43843b397 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -5,7 +5,7 @@ using Zygote: @adjoint, pullback import Zygote: literal_getproperty import ChainRulesCore using SciMLBase -using SciMLBase: ODESolution, remake, +using SciMLBase: ODESolution, remake, ODEFunction, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, @@ -13,6 +13,14 @@ using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_ using RecursiveArrayTools import SciMLStructures +@adjoint function SciMLBase.remake(prob::ODEFunction; kw...) + y = remake(prob; kw...) + function odefunction_remake_back(Δ) + (Δ,) + end + y, odefunction_remake_back +end + # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt # https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67