Skip to content

Commit 8641bcf

Browse files
committed
Add observe fun for SteadyStateProblem
1 parent 972835b commit 8641bcf

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
145145
eval_expression = true,
146146
sparse = false, simplify=false,
147147
eval_module = @__MODULE__,
148+
steady_state = false,
148149
kwargs...) where {iip}
149150

150151
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
@@ -178,12 +179,23 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
178179

179180
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
180181

181-
observedfun = let sys = sys, dict = Dict()
182-
function generated_observed(obsvar, u, p, t)
183-
obs = get!(dict, value(obsvar)) do
184-
build_explicit_observed_function(sys, obsvar)
182+
observedfun = if steady_state
183+
let sys = sys, dict = Dict()
184+
function generated_observed(obsvar, u, p, t=Inf)
185+
obs = get!(dict, value(obsvar)) do
186+
build_explicit_observed_function(sys, obsvar)
187+
end
188+
obs(u, p, t)
189+
end
190+
end
191+
else
192+
let sys = sys, dict = Dict()
193+
function generated_observed(obsvar, u, p, t)
194+
obs = get!(dict, value(obsvar)) do
195+
build_explicit_observed_function(sys, obsvar)
196+
end
197+
obs(u, p, t)
185198
end
186-
obs(u, p, t)
187199
end
188200
end
189201

@@ -228,9 +240,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
228240
jac = false,
229241
linenumbers = false,
230242
sparse = false, simplify=false,
243+
steady_state = false,
231244
kwargs...) where {iip}
232245

233246
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)
247+
248+
dict = Dict()
249+
observedfun = if steady_state
250+
:(function generated_observed(obsvar, u, p, t=Inf)
251+
obs = get!($dict, value(obsvar)) do
252+
build_explicit_observed_function($sys, obsvar)
253+
end
254+
obs(u, p, t)
255+
end)
256+
else
257+
:(function generated_observed(obsvar, u, p, t)
258+
obs = get!($dict, value(obsvar)) do
259+
build_explicit_observed_function($sys, obsvar)
260+
end
261+
obs(u, p, t)
262+
end)
263+
end
264+
234265
fsym = gensym(:f)
235266
_f = :($fsym = ModelingToolkit.ODEFunctionClosure($f_oop, $f_iip))
236267
tgradsym = gensym(:tgrad)
@@ -271,6 +302,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
271302
jac_prototype = $jp_expr,
272303
syms = $(Symbol.(states(sys))),
273304
indepsym = $(QuoteNode(Symbol(independent_variable(sys)))),
305+
observed = $observedfun,
274306
)
275307
end
276308
!linenumbers ? striplines(ex) : ex
@@ -379,7 +411,7 @@ end
379411

380412
"""
381413
```julia
382-
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,tspan,
414+
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,
383415
parammap=DiffEqBase.NullParameters();
384416
version = nothing, tgrad=false,
385417
jac = false,
@@ -393,13 +425,13 @@ symbolically calculating numerical enhancements.
393425
function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem,u0map,
394426
parammap=DiffEqBase.NullParameters();
395427
kwargs...) where iip
396-
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
397-
SteadyStateProblem(f,u0,p;kwargs...)
428+
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; steady_state = true, kwargs...)
429+
SteadyStateProblem{iip}(f,u0,p;kwargs...)
398430
end
399431

400432
"""
401433
```julia
402-
function DiffEqBase.SteadyStateProblemExpr(sys::AbstractODESystem,u0map,tspan,
434+
function DiffEqBase.SteadyStateProblemExpr(sys::AbstractODESystem,u0map,
403435
parammap=DiffEqBase.NullParameters();
404436
version = nothing, tgrad=false,
405437
jac = false,
@@ -417,7 +449,7 @@ struct SteadyStateProblemExpr{iip} end
417449
function SteadyStateProblemExpr{iip}(sys::AbstractODESystem,u0map,
418450
parammap=DiffEqBase.NullParameters();
419451
kwargs...) where iip
420-
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
452+
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap;steady_state = true, kwargs...)
421453
linenumbers = get(kwargs, :linenumbers, true)
422454
ex = quote
423455
f = $f

test/reduction.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ u0 = [
143143
prob1 = ODEProblem(reduced_system, u0, (0.0, 100.0), pp)
144144
solve(prob1, Rodas5())
145145

146+
prob2 = SteadyStateProblem(reduced_system, u0, pp)
147+
@test prob2.f.observed(lorenz2.u, prob2.u0, pp) === 1.0
148+
149+
146150
# issue #724 and #716
147151
let
148152
@parameters t

0 commit comments

Comments
 (0)