Skip to content

Commit 06f3672

Browse files
chore: add getindex adjoint for NonlinearSolution
1 parent f5b6ba7 commit 06f3672

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,36 @@ function ∇responsible_map(cx, f, args...)
260260
end
261261
end
262262

263+
264+
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
265+
function NonlinearSolution_getindex_pullback(Δ)
266+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
267+
if is_observed(VA, sym)
268+
f = SII.observed(VA, sym)
269+
p = parameter_values(VA)
270+
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
271+
u = state_values(VA)
272+
t = current_time(VA)
273+
y, back = Zygote.pullback(u, tunables) do u, tunables
274+
_p = repack(tunables)
275+
# @show f.f_oop.(u, Ref(_p), t)
276+
@show f.f_oop(_p, _p)
277+
end
278+
gs = back(Δ)
279+
(gs[1], nothing)
280+
elseif i === nothing
281+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
282+
else
283+
VA = recursivecopy(VA)
284+
recursivefill!(VA, zero(eltype(VA)))
285+
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
286+
copyto!(v, Δ)
287+
(VA, nothing)
288+
end
289+
end
290+
VA[sym], NonlinearSolution_getindex_pullback
291+
end
292+
263293
@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
264294
∇tmap(__context__, f, args...)
265295
end

0 commit comments

Comments
 (0)