Skip to content

Commit a6b7149

Browse files
chore: handle NonlinearSolution in getindex(::NonlinearSolution, ::Vector{Num})
1 parent 06f3672 commit a6b7149

File tree

1 file changed

+37
-31
lines changed

1 file changed

+37
-31
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,19 @@ function obs_grads(VA, sym, obs_idx, Δ)
116116
back(Δobs)
117117
end
118118

119+
function obs_grads2(VA::SciMLBase.NonlinearSolution, sym, obs_idx, Δ)
120+
y, back = Zygote.pullback(VA) do sol
121+
getindex.(Ref(sol), sym[obs_idx])
122+
end
123+
Δobs = Δ[obs_idx, :]
124+
back(Δobs)
125+
end
126+
119127
function obs_grads(VA, sym, ::Nothing, Δ)
120128
Zygote.nt_nothing(VA)
121129
end
122130

123-
function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T}
131+
function not_obs_grads(VA::DESolution{T}, sym, not_obss_idx, i, Δ) where {T}
124132
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
125133
map(enumerate(us)) do (u_idx, u)
126134
if u_idx in i
@@ -154,6 +162,34 @@ end
154162
VA[sym], ODESolution_getindex_pullback
155163
end
156164

165+
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
166+
function NonlinearSolution_getindex_pullback(Δ)
167+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
168+
if is_observed(VA, sym)
169+
f = SII.observed(VA, sym)
170+
p = parameter_values(VA)
171+
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
172+
u = state_values(VA)
173+
t = current_time(VA)
174+
y, back = Zygote.pullback(u, tunables) do u, tunables
175+
_p = repack(tunables)
176+
f.f_oop(_p, _p)
177+
end
178+
gs = back(Δ)
179+
(gs[1], nothing)
180+
elseif i === nothing
181+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
182+
else
183+
VA = recursivecopy(VA)
184+
recursivefill!(VA, zero(eltype(VA)))
185+
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
186+
copyto!(v, Δ)
187+
(VA, nothing)
188+
end
189+
end
190+
VA[sym], NonlinearSolution_getindex_pullback
191+
end
192+
157193
@adjoint function ODESolution{
158194
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u,
159195
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
@@ -260,36 +296,6 @@ function ∇responsible_map(cx, f, args...)
260296
end
261297
end
262298

263-
264-
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
265-
function NonlinearSolution_getindex_pullback(Δ)
266-
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
267-
if is_observed(VA, sym)
268-
f = SII.observed(VA, sym)
269-
p = parameter_values(VA)
270-
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
271-
u = state_values(VA)
272-
t = current_time(VA)
273-
y, back = Zygote.pullback(u, tunables) do u, tunables
274-
_p = repack(tunables)
275-
# @show f.f_oop.(u, Ref(_p), t)
276-
@show f.f_oop(_p, _p)
277-
end
278-
gs = back(Δ)
279-
(gs[1], nothing)
280-
elseif i === nothing
281-
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
282-
else
283-
VA = recursivecopy(VA)
284-
recursivefill!(VA, zero(eltype(VA)))
285-
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
286-
copyto!(v, Δ)
287-
(VA, nothing)
288-
end
289-
end
290-
VA[sym], NonlinearSolution_getindex_pullback
291-
end
292-
293299
@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
294300
∇tmap(__context__, f, args...)
295301
end

0 commit comments

Comments
 (0)