Skip to content

Commit b5288ad

Browse files
Merge pull request #984 from DhairyaLGandhi/dg/rem
chore: add remake adjoint for `ODEFunction`
2 parents f6c01b2 + effff58 commit b5288ad

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@ using Zygote: @adjoint, pullback
55
import Zygote: literal_getproperty
66
import ChainRulesCore
77
using SciMLBase
8-
using SciMLBase: ODESolution, remake,
8+
using SciMLBase: ODESolution, remake, ODEFunction,
99
getobserved, build_solution, EnsembleSolution,
1010
NonlinearSolution, AbstractTimeseriesSolution
1111
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
1212
observed, parameter_values, state_values, current_time
1313
using RecursiveArrayTools
1414
import SciMLStructures
1515

16+
@adjoint function SciMLBase.remake(prob::ODEFunction; kw...)
17+
y = remake(prob; kw...)
18+
function odefunction_remake_back(Δ)
19+
(Δ,)
20+
end
21+
y, odefunction_remake_back
22+
end
23+
1624
# This method resolves the ambiguity with the pullback defined in
1725
# RecursiveArrayToolsZygoteExt
1826
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67

0 commit comments

Comments
 (0)