Skip to content

Commit 584f8c3

Browse files
chore: return deref'd tangent for getproperty(ODEProblem)
1 parent d50778b commit 584f8c3

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module SciMLBaseZygoteExt
22

33
using Zygote
4-
using Zygote: @adjoint, pullback
5-
import Zygote: literal_getproperty
4+
using Zygote: @adjoint, pullback, @_adjoint_keepthunks, _project, pair
5+
import Zygote: literal_getproperty, literal_getfield
66
import ChainRulesCore
77
using SciMLBase
88
using SciMLBase: ODESolution, remake, ODEFunction,
99
getobserved, build_solution, EnsembleSolution,
10-
NonlinearSolution, AbstractTimeseriesSolution
10+
NonlinearSolution, AbstractTimeseriesSolution,
11+
ODEProblem
1112
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed,
1213
observed, parameter_values, state_values, current_time
1314
using RecursiveArrayTools
@@ -299,4 +300,20 @@ end
299300
∇responsible_map(__context__, f, args...)
300301
end
301302

303+
@_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f
304+
val = getfield(x, f)
305+
function back(Δ)
306+
Zygote.accum_param(__context__, val, Δ) === nothing && return
307+
if isimmutable(x)
308+
dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...)
309+
(_project(x, dx), nothing)
310+
else
311+
dx = Zygote.grad_mut(__context__, x)
312+
dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...)
313+
return (dx[],nothing)
314+
end
315+
end
316+
Zygote.unwrap(val), back
317+
end
318+
302319
end

0 commit comments

Comments
 (0)