Skip to content

Commit effff58

Browse files
chore: add remake adjoint for odefunciton
1 parent 45612a9 commit effff58

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@ using Zygote: @adjoint, pullback
55
import Zygote: literal_getproperty
66
import ChainRulesCore
77
using SciMLBase
8-
using SciMLBase: ODESolution, remake,
8+
using SciMLBase: ODESolution, remake, ODEFunction,
99
getobserved, build_solution, EnsembleSolution,
1010
NonlinearSolution, AbstractTimeseriesSolution
1111
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
1212
observed, parameter_values, state_values, current_time
1313
using RecursiveArrayTools
1414
import SciMLStructures
1515

16+
@adjoint function SciMLBase.remake(prob::ODEFunction; kw...)
17+
y = remake(prob; kw...)
18+
function odefunction_remake_back(Δ)
19+
(Δ,)
20+
end
21+
y, odefunction_remake_back
22+
end
23+
1624
# This method resolves the ambiguity with the pullback defined in
1725
# RecursiveArrayToolsZygoteExt
1826
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67

0 commit comments

Comments
 (0)