From effff586654617612f71608c0f43bc778ecc15b3 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Wed, 9 Apr 2025 15:34:40 +0530 Subject: [PATCH] chore: add remake adjoint for odefunciton --- ext/SciMLBaseZygoteExt.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index d785e4ef4..43843b397 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -5,7 +5,7 @@ using Zygote: @adjoint, pullback import Zygote: literal_getproperty import ChainRulesCore using SciMLBase -using SciMLBase: ODESolution, remake, +using SciMLBase: ODESolution, remake, ODEFunction, getobserved, build_solution, EnsembleSolution, NonlinearSolution, AbstractTimeseriesSolution using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, @@ -13,6 +13,14 @@ using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_ using RecursiveArrayTools import SciMLStructures +@adjoint function SciMLBase.remake(prob::ODEFunction; kw...) + y = remake(prob; kw...) + function odefunction_remake_back(Δ) + (Δ,) + end + y, odefunction_remake_back +end + # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt # https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67