Skip to content

Commit eae18e0

Browse files
Merge pull request #952 from DhairyaLGandhi/dg/nonlinear
Adjoint through `getindex(::NonlinearSolution, sym)`
2 parents 0329949 + 25bcea9 commit eae18e0

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ function obs_grads(VA, sym, obs_idx, Δ)
116116
back(Δobs)
117117
end
118118

119+
# Pull a cotangent back through symbolic indexing of a nonlinear solution.
#
# `VA` is indexed with every symbol in `sym[obs_idx]`; the entries of `Δ`
# selected by `obs_idx` along its first dimension are then fed to the
# resulting Zygote pullback, whose output is returned unchanged.
function obs_grads2(VA::SciMLBase.NonlinearSolution, sym, obs_idx, Δ)
    # Forward pass: only the pullback is needed, the primal value is dropped.
    _, pull = Zygote.pullback(sol -> getindex.(Ref(sol), sym[obs_idx]), VA)
    return pull(Δ[obs_idx, :])
end
126+
119127
# Method selected when the third argument is `nothing`: `sym` and `Δ` are
# ignored and the result of `Zygote.nt_nothing(VA)` is returned.
obs_grads(VA, sym, ::Nothing, Δ) = Zygote.nt_nothing(VA)
@@ -154,6 +162,31 @@ end
154162
VA[sym], ODESolution_getindex_pullback
155163
end
156164

165+
# Zygote adjoint for `VA[sym]` on a `NonlinearSolution`.
#
# The pullback receives the cotangent `Δ` of `VA[sym]` and returns a tuple
# `(∂VA, nothing)` (no gradient is produced for `sym`):
#   * if `sym` is observed, `Δ` is pulled back through the out-of-place
#     observed function and returned as `(u = ..., prob = (p = ...,))`;
#   * if `sym` is symbolic but `variable_index` yields `nothing`, an error is
#     raised;
#   * otherwise `∂VA` is a zero-filled recursive copy of `VA` with `Δ` copied
#     into the slice at index `i`.
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
    function NonlinearSolution_getindex_pullback(Δ)
        # Non-symbolic `sym` is used directly as the index.
        i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
        if is_observed(VA, sym)
            f = observed(VA, sym)
            p = parameter_values(VA)
            u = state_values(VA)
            _, back = Zygote.pullback(u, p) do u, p
                f.f_oop(u, p)
            end
            gs = back(Δ)
            ((u = gs[1], prob = (p = gs[2],),), nothing)
        elseif i === nothing
            # `error` already throws; wrapping it in `throw` was redundant.
            error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
        else
            # Use a fresh local for the cotangent. Re-assigning the captured
            # `VA` here would box the variable and rebind the primal solution
            # to the zero-filled copy for any subsequent pullback call.
            dVA = recursivecopy(VA)
            recursivefill!(dVA, zero(eltype(dVA)))
            v = view(dVA, i, ntuple(_ -> :, ndims(dVA) - 1)...)
            copyto!(v, Δ)
            (dVA, nothing)
        end
    end
    VA[sym], NonlinearSolution_getindex_pullback
end
189+
157190
@adjoint function ODESolution{
158191
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u,
159192
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,

test/downstream/observables_autodiff.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ sol = solve(prob, Tsit5())
4646
@test du == gs
4747
end
4848

49+
# @testset "AD Observable Functions for Initialization" begin
50+
# iprob = prob.f.initialization_data.initializeprob
51+
# isol = solve(iprob)
52+
# gs, = gradient(isol) do isol
53+
# isol[w]
54+
# end
55+
56+
# end
57+
4958
# DAE
5059

5160
function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@@ -84,18 +93,29 @@ end
8493
du_ = [0.2, 1.0]
8594
du = [du_ for _ in sol.u]
8695
@test gs == du
96+
97+
@testset "DAE Initialization Observable function AD" begin
    # Solve the `initializeprob` stored in `prob.f.initialization_data` and
    # differentiate an observable read from that solution.
    iprob = prob.f.initialization_data.initializeprob
    isol = solve(iprob)
    gs, = gradient(isol) do isol
        isol[sys.ampermeter.i]
    end
    gt = gs.prob.p.tunable
    # Exactly one entry of the tunable gradient is expected to be non-zero.
    @test count(!iszero, gt) == 1
end
87107
end
88108

89109
# @testset "Adjoints with DAE" begin
90-
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
91-
# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
110+
# gs_mtkp, gs_p_new = gradient(prob.p, prob.p.tunable) do p, new_tunables
111+
# new_p = SS.replace(SS.Tunable(), p, new_tunables)
92112
# new_prob = remake(prob, p = new_p)
93113
# sol = solve(new_prob, Rodas4())
94114
# @show size(sol)
95115
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
96116
# sum(sol[sys.ampermeter.i])
97117
# end
98-
#
118+
#
99119
# @test isnothing(gs_mtkp)
100120
# @test length(gs_p_new) == length(p_new)
101121
# end

0 commit comments

Comments
 (0)