|
1 | 1 | module SciMLBaseZygoteExt |
2 | 2 |
|
3 | 3 | using Zygote |
4 | | -using Zygote: @adjoint, pullback |
5 | | -import Zygote: literal_getproperty |
| 4 | +using Zygote: @adjoint, pullback, @_adjoint_keepthunks, _project, pair |
| 5 | +import Zygote: literal_getproperty, literal_getfield |
6 | 6 | import ChainRulesCore |
7 | 7 | using SciMLBase |
8 | 8 | using SciMLBase: ODESolution, remake, ODEFunction, |
9 | 9 | getobserved, build_solution, EnsembleSolution, |
10 | | - NonlinearSolution, AbstractTimeseriesSolution |
| 10 | + NonlinearSolution, AbstractTimeseriesSolution, |
| 11 | + ODEProblem |
11 | 12 | using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, |
12 | 13 | observed, parameter_values, state_values, current_time |
13 | 14 | using RecursiveArrayTools |
|
299 | 300 | ∇responsible_map(__context__, f, args...) |
300 | 301 | end |
301 | 302 |
|
| 303 | +@_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f |
| 304 | + val = getfield(x, f) |
| 305 | + function back(Δ) |
| 306 | + Zygote.accum_param(__context__, val, Δ) === nothing && return |
| 307 | + if isimmutable(x) |
| 308 | + dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) |
| 309 | + (_project(x, dx), nothing) |
| 310 | + else |
| 311 | + dx = Zygote.grad_mut(__context__, x) |
| 312 | + dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) |
| 313 | + return (dx[],nothing) |
| 314 | + end |
| 315 | + end |
| 316 | + Zygote.unwrap(val), back |
| 317 | +end |
| 318 | + |
302 | 319 | end |
0 commit comments