Skip to content

Commit 99bdb1c

Browse files
author
Oscar Smith
authored
Merge branch 'master' into os/ForwardDiff1.0
2 parents c5ee03a + ba51e90 commit 99bdb1c

File tree

4 files changed

+75
-16
lines changed

4 files changed

+75
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.79.1"
4+
version = "2.80.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/SciMLBaseZygoteExt.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ function obs_grads(VA, sym, obs_idx, Δ)
116116
back(Δobs)
117117
end
118118

119+
function obs_grads2(VA::SciMLBase.NonlinearSolution, sym, obs_idx, Δ)
120+
y, back = Zygote.pullback(VA) do sol
121+
getindex.(Ref(sol), sym[obs_idx])
122+
end
123+
Δobs = Δ[obs_idx, :]
124+
back(Δobs)
125+
end
126+
119127
function obs_grads(VA, sym, ::Nothing, Δ)
120128
Zygote.nt_nothing(VA)
121129
end
@@ -154,6 +162,31 @@ end
154162
VA[sym], ODESolution_getindex_pullback
155163
end
156164

165+
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
166+
function NonlinearSolution_getindex_pullback(Δ)
167+
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
168+
if is_observed(VA, sym)
169+
f = observed(VA, sym)
170+
p = parameter_values(VA)
171+
u = state_values(VA)
172+
_, back = Zygote.pullback(u, p) do u, p
173+
f.f_oop(u, p)
174+
end
175+
gs = back(Δ)
176+
((u = gs[1], prob = (p = gs[2],),), nothing)
177+
elseif i === nothing
178+
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
179+
else
180+
VA = recursivecopy(VA)
181+
recursivefill!(VA, zero(eltype(VA)))
182+
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
183+
copyto!(v, Δ)
184+
(VA, nothing)
185+
end
186+
end
187+
VA[sym], NonlinearSolution_getindex_pullback
188+
end
189+
157190
@adjoint function ODESolution{
158191
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u,
159192
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,

src/scimlfunctions.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,8 @@ dt: the time step
948948
949949
```julia
950950
ImplicitDiscreteFunction{iip,specialize}(f;
951-
analytic = __has_analytic(f) ? f.analytic : nothing)
951+
analytic = __has_analytic(f) ? f.analytic : nothing,
952+
resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing)
952953
```
953954
954955
Note that only the function `f` itself is required. This function should
@@ -973,12 +974,13 @@ For more details on this argument, see the ODEFunction documentation.
973974
974975
The fields of the ImplicitDiscreteFunction type directly match the names of the inputs.
975976
"""
976-
struct ImplicitDiscreteFunction{iip, specialize, F, Ta, O, SYS, ID} <:
977+
struct ImplicitDiscreteFunction{iip, specialize, F, Ta, O, SYS, RP, ID} <:
977978
AbstractDiscreteFunction{iip}
978979
f::F
979980
analytic::Ta
980981
observed::O
981982
sys::SYS
983+
resid_prototype::RP
982984
initialization_data::ID
983985
end
984986

@@ -3064,6 +3066,9 @@ function ImplicitDiscreteFunction{iip, specialize}(f;
30643066
observed = __has_observed(f) ?
30653067
f.observed :
30663068
DEFAULT_OBSERVED,
3069+
resid_prototype = __has_resid_prototype(f) ?
3070+
f.resid_prototype :
3071+
nothing,
30673072
sys = __has_sys(f) ? f.sys : nothing,
30683073
initialization_data = __has_initialization_data(f) ? f.initialization_data :
30693074
nothing) where {
@@ -3074,39 +3079,40 @@ function ImplicitDiscreteFunction{iip, specialize}(f;
30743079
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
30753080

30763081
if specialize === NoSpecialize
3077-
ImplicitDiscreteFunction{iip, specialize, Any, Any, Any, Any, Any}(_f,
3082+
ImplicitDiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any}(_f,
30783083
analytic,
30793084
observed,
30803085
sys,
3086+
resid_prototype,
30813087
initialization_data)
30823088
else
30833089
ImplicitDiscreteFunction{
3084-
iip, specialize, typeof(_f), typeof(analytic), typeof(observed), typeof(sys),
3090+
iip, specialize, typeof(_f), typeof(analytic), typeof(observed), typeof(sys), typeof(resid_prototype),
30853091
typeof(initialization_data)}(
3086-
_f, analytic, observed, sys, initialization_data)
3092+
_f, analytic, observed, sys, resid_prototype, initialization_data)
30873093
end
30883094
end
30893095

30903096
function ImplicitDiscreteFunction{iip}(f; kwargs...) where {iip}
30913097
ImplicitDiscreteFunction{iip, FullSpecialize}(f; kwargs...)
30923098
end
30933099
ImplicitDiscreteFunction{iip}(f::ImplicitDiscreteFunction; kwargs...) where {iip} = f
3094-
function ImplicitDiscreteFunction(f; kwargs...)
3095-
ImplicitDiscreteFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...)
3100+
function ImplicitDiscreteFunction(f; resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing, kwargs...)
3101+
ImplicitDiscreteFunction{isinplace(f, 5), FullSpecialize}(f; resid_prototype, kwargs...)
30963102
end
30973103
ImplicitDiscreteFunction(f::ImplicitDiscreteFunction; kwargs...) = f
30983104

30993105
function unwrapped_f(f::ImplicitDiscreteFunction, newf = unwrapped_f(f.f))
31003106
specialize = specialization(f)
31013107

31023108
if specialize === NoSpecialize
3103-
ImplicitDiscreteFunction{isinplace(f, 6), specialize, Any, Any,
3104-
Any, Any, Any}(newf, f.analytic, f.observed, f.sys, f.initialization_data)
3109+
ImplicitDiscreteFunction{isinplace(f, 5), specialize, Any, Any, Any,
3110+
Any, Any, Any}(newf, f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
31053111
else
3106-
ImplicitDiscreteFunction{isinplace(f, 6), specialize, typeof(newf),
3112+
ImplicitDiscreteFunction{isinplace(f, 5), specialize, typeof(newf),
31073113
typeof(f.analytic),
3108-
typeof(f.observed), typeof(f.sys), typeof(f.initialization_data)}(newf,
3109-
f.analytic, f.observed, f.sys, f.initialization_data)
3114+
typeof(f.observed), typeof(f.sys), typeof(resid_prototype), typeof(f.initialization_data)}(newf,
3115+
f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
31103116
end
31113117
end
31123118

test/downstream/observables_autodiff.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ sol = solve(prob, Tsit5())
4646
@test du == gs
4747
end
4848

49+
# @testset "AD Observable Functions for Initialization" begin
50+
# iprob = prob.f.initialization_data.initalizeprob
51+
# isol = solve(iprob)
52+
# gs, = gradient(isol) do isol
53+
# isol[w]
54+
# end
55+
56+
# end
57+
4958
# DAE
5059

5160
function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@@ -84,18 +93,29 @@ end
8493
du_ = [0.2, 1.0]
8594
du = [du_ for _ in sol.u]
8695
@test gs == du
96+
97+
@testset "DAE Initialization Observable function AD" begin
98+
iprob = prob.f.initialization_data.initializeprob
99+
isol = solve(iprob)
100+
tunables, repack, _ = SS.canonicalize(SS.Tunable(), SII.parameter_values(iprob))
101+
gs, = gradient(isol) do isol
102+
isol[sys.ampermeter.i]
103+
end
104+
gt = gs.prob.p.tunable
105+
@test length(findall(!iszero, gt)) == 1
106+
end
87107
end
88108

89109
# @testset "Adjoints with DAE" begin
90-
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
91-
# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
110+
# gs_mtkp, gs_p_new = gradient(prob.p, prob.p.tunable) do p, new_tunables
111+
# new_p = SS.replace(SS.Tunable(), p, new_tunables)
92112
# new_prob = remake(prob, p = new_p)
93113
# sol = solve(new_prob, Rodas4())
94114
# @show size(sol)
95115
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
96116
# sum(sol[sys.ampermeter.i])
97117
# end
98-
#
118+
#
99119
# @test isnothing(gs_mtkp)
100120
# @test length(gs_p_new) == length(p_new)
101121
# end

0 commit comments

Comments
 (0)