Skip to content

Commit 73be894

Browse files
authored
Merge pull request #926 from SciML/myb/obs
Add observed function for `NonlinearFunction`s
2 parents 24d5041 + 2516a67 commit 73be894

File tree

4 files changed

+65
-17
lines changed

4 files changed

+65
-17
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
158158
eval_expression = true,
159159
sparse = false, simplify=false,
160160
eval_module = @__MODULE__,
161+
steady_state = false,
161162
checkbounds=false,
162163
kwargs...) where {iip}
163164

@@ -194,12 +195,23 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
194195

195196
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
196197

197-
observedfun = let sys = sys, dict = Dict()
198-
function generated_observed(obsvar, u, p, t)
199-
obs = get!(dict, value(obsvar)) do
200-
build_explicit_observed_function(sys, obsvar; checkbounds=checkbounds)
198+
observedfun = if steady_state
199+
let sys = sys, dict = Dict()
200+
function generated_observed(obsvar, u, p, t=Inf)
201+
obs = get!(dict, value(obsvar)) do
202+
build_explicit_observed_function(sys, obsvar)
203+
end
204+
obs(u, p, t)
205+
end
206+
end
207+
else
208+
let sys = sys, dict = Dict()
209+
function generated_observed(obsvar, u, p, t)
210+
obs = get!(dict, value(obsvar)) do
211+
build_explicit_observed_function(sys, obsvar; checkbounds=checkbounds)
212+
end
213+
obs(u, p, t)
201214
end
202-
obs(u, p, t)
203215
end
204216
end
205217

@@ -301,9 +313,30 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
301313
jac = false,
302314
linenumbers = false,
303315
sparse = false, simplify=false,
316+
steady_state = false,
304317
kwargs...) where {iip}
305318

306319
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)
320+
321+
dict = Dict()
322+
#=
323+
observedfun = if steady_state
324+
:(function generated_observed(obsvar, u, p, t=Inf)
325+
obs = get!($dict, value(obsvar)) do
326+
build_explicit_observed_function($sys, obsvar)
327+
end
328+
obs(u, p, t)
329+
end)
330+
else
331+
:(function generated_observed(obsvar, u, p, t)
332+
obs = get!($dict, value(obsvar)) do
333+
build_explicit_observed_function($sys, obsvar)
334+
end
335+
obs(u, p, t)
336+
end)
337+
end
338+
=#
339+
307340
fsym = gensym(:f)
308341
_f = :($fsym = ModelingToolkit.ODEFunctionClosure($f_oop, $f_iip))
309342
tgradsym = gensym(:tgrad)
@@ -579,7 +612,7 @@ end
579612

580613
"""
581614
```julia
582-
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,tspan,
615+
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,
583616
parammap=DiffEqBase.NullParameters();
584617
version = nothing, tgrad=false,
585618
jac = false,
@@ -593,13 +626,13 @@ symbolically calculating numerical enhancements.
593626
function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem,u0map,
594627
parammap=DiffEqBase.NullParameters();
595628
kwargs...) where iip
596-
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
597-
SteadyStateProblem(f,u0,p;kwargs...)
629+
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; steady_state = true, kwargs...)
630+
SteadyStateProblem{iip}(f,u0,p;kwargs...)
598631
end
599632

600633
"""
601634
```julia
602-
function DiffEqBase.SteadyStateProblemExpr(sys::AbstractODESystem,u0map,tspan,
635+
function DiffEqBase.SteadyStateProblemExpr(sys::AbstractODESystem,u0map,
603636
parammap=DiffEqBase.NullParameters();
604637
version = nothing, tgrad=false,
605638
jac = false,
@@ -617,7 +650,7 @@ struct SteadyStateProblemExpr{iip} end
617650
function SteadyStateProblemExpr{iip}(sys::AbstractODESystem,u0map,
618651
parammap=DiffEqBase.NullParameters();
619652
kwargs...) where iip
620-
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
653+
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap;steady_state = true, kwargs...)
621654
linenumbers = get(kwargs, :linenumbers, true)
622655
ex = quote
623656
f = $f

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,13 @@ function build_explicit_observed_function(
252252
output[i] = obs[idx].rhs
253253
end
254254

255+
dvs = DestructuredArgs(states(sys), inbounds=!checkbounds)
256+
ps = DestructuredArgs(parameters(sys), inbounds=!checkbounds)
257+
iv = independent_variable(sys)
258+
args = iv === nothing ? [dvs, ps] : [dvs, ps, iv]
259+
255260
ex = Func(
256-
[
257-
DestructuredArgs(states(sys), inbounds=!checkbounds)
258-
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
259-
independent_variable(sys)
260-
],
261-
[],
261+
args, [],
262262
Let(
263263
map(eq -> eq.lhseq.rhs, obs[1:maxidx]),
264264
isscalar ? output[1] : MakeArray(output, output_type)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,19 @@ function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sy
142142
_jac = nothing
143143
end
144144

145+
observedfun = let sys = sys, dict = Dict()
146+
function generated_observed(obsvar, u, p)
147+
obs = get!(dict, value(obsvar)) do
148+
build_explicit_observed_function(sys, obsvar)
149+
end
150+
obs(u, p)
151+
end
152+
end
153+
145154
NonlinearFunction{iip}(f,
146155
jac = _jac === nothing ? nothing : _jac,
147156
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
148-
syms = Symbol.(states(sys)))
157+
syms = Symbol.(states(sys)), observed = observedfun)
149158
end
150159

151160
"""

test/reduction.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ u0 = [
147147
prob1 = ODEProblem(reduced_system, u0, (0.0, 100.0), pp)
148148
solve(prob1, Rodas5())
149149

150+
prob2 = SteadyStateProblem(reduced_system, u0, pp)
151+
@test prob2.f.observed(lorenz2.u, prob2.u0, pp) === 1.0
152+
153+
150154
# issue #724 and #716
151155
let
152156
@parameters t
@@ -204,6 +208,8 @@ nlprob = NonlinearProblem(reducedsys, u0, pp)
204208
reducedsol = solve(nlprob, NewtonRaphson())
205209
residual = fill(100.0, length(states(reducedsys)))
206210
nlprob.f(residual, reducedsol.u, pp)
211+
@test hypot(nlprob.f.observed(u2, reducedsol.u, pp), nlprob.f.observed(u1, reducedsol.u, pp)) * pp reducedsol.u atol=1e-9
212+
207213
@test all(x->abs(x) < 1e-5, residual)
208214

209215
N = 5

0 commit comments

Comments
 (0)