Skip to content

Commit 4e0781f

Browse files
YingboMashashi
andcommitted
Convert observed equations as assignments when lowering
Co-authored-by: "Shashi Gowda" <[email protected]>
1 parent 23ab27a commit 4e0781f

File tree

3 files changed

+51
-10
lines changed

3 files changed

+51
-10
lines changed

src/build_function.jl

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,36 @@ function unflatten_long_ops(op, N=4)
9797
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op)
9898
end
9999

100+
struct Let
101+
eqs::Vector
102+
body
103+
end
104+
105+
function observed_let(eqs)
106+
process -> ex -> begin
107+
lhss = map(eq->process(eq.lhs), eqs)
108+
rhss = map(eq->process(eq.rhs), eqs)
109+
letexpr = Expr(:let)
110+
assignments = quote end
111+
for (l, r) in zip(lhss, rhss)
112+
push!(assignments.args, :($l = $r))
113+
end
114+
push!(letexpr.args, assignments)
115+
push!(letexpr.args, ex)
116+
letexpr
117+
end
118+
end
119+
120+
function _build_function(target::JuliaTarget, op::Let, args...; conv=toexpr, kw...)
121+
_build_function(target, op.body, args...;
122+
inner_let = observed_let(op.eqs), kw...)
123+
end
124+
100125
# Scalar output
101126
function _build_function(target::JuliaTarget, op, args...;
102127
conv = toexpr, expression = Val{true},
103128
checkbounds = false,
129+
inner_let = nothing,
104130
linenumbers = true, headerfun=addheader)
105131

106132
argnames = [gensym(:MTKArg) for i in 1:length(args)]
@@ -109,12 +135,18 @@ function _build_function(target::JuliaTarget, op, args...;
109135
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
110136
ls = reduce(vcat,conv.(first.(arg_pairs)))
111137
rs = reduce(vcat,last.(arg_pairs))
112-
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, conv.(process.(rs))))
138+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, conv.(process.(rs))))
139+
140+
if inner_let !== nothing
141+
inner_let_expr = inner_let(conv process)
142+
else
143+
inner_let_expr = identity
144+
end
113145

114146
fname = gensym(:ModelingToolkitFunction)
115147
op = process(op)
116148
out_expr = conv(substitute(op, symsdict, fold=false))
117-
let_expr = Expr(:let, var_eqs, Expr(:block, out_expr))
149+
let_expr = Expr(:let, var_eqs, Expr(:block, inner_let_expr(out_expr)))
118150
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
119151

120152
fargs = Expr(:tuple,argnames...)
@@ -218,6 +250,7 @@ Special Keyword Argumnets:
218250
"""
219251
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
220252
conv = toexpr, expression = Val{true},
253+
inner_let = nothing,
221254
checkbounds = false,
222255
linenumbers = false, multithread=nothing,
223256
headerfun = addheader, outputidxs=nothing,
@@ -235,6 +268,12 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
235268
arg_pairs = map((x,y)->vars_to_pairs(x,y, symsdict), argnames, args)
236269
process = unflatten_long_ops(x->substitute(x, symsdict, fold=false))
237270

271+
if inner_let !== nothing
272+
inner_let_expr = inner_let(conv process)
273+
else
274+
inner_let_expr = identity
275+
end
276+
238277
ls = reduce(vcat,conv.(first.(arg_pairs)))
239278
rs = reduce(vcat,last.(arg_pairs))
240279
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, conv.(process.(rs))))
@@ -440,9 +479,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
440479
end
441480
end : arr_sys_expr
442481

443-
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
444-
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
445-
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
482+
arr_let_expr = Expr(:let, var_eqs, inner_let_expr(arr_sys_expr))
483+
idx = findfirst(x->Meta.isexpr(x, :let), ip_let_expr.args)
484+
ip_let_expr.args[idx].args[2] = inner_let_expr(ip_let_expr.args[idx].args[2])
485+
446486
oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
447487
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
448488

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,18 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
6868
end
6969

7070
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
71+
dvs = vcat(dvs, map(eq->eq.lhs, observed(sys)))
7172
# optimization
7273
dvs′ = makesym.(value.(dvs), states=dvs)
7374
ps′ = makesym.(value.(ps), states=dvs)
7475

7576
sub = Dict(dvs .=> dvs′)
7677
# substitute x(t) by just x
7778
rhss = [substitute(deq.rhs, sub) for deq equations(sys)]
78-
return build_function(rhss, dvs′, ps′, sys.iv;
79+
obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq observed(sys)]
80+
81+
# TODO: add an optional check on the ordering of observed equations
82+
return build_function(Let(obss, rhss), dvs′, ps′, sys.iv;
7983
conv = ODEToExpr(sys),kwargs...)
8084
end
8185

src/systems/reduction.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,12 @@ function alias_elimination(sys::ODESystem)
107107
end
108108
end
109109

110-
eqs′ = substitute_aliases(neweqs, Dict(subs))
111-
112110
alias_vars = first.(subs)
113111
sys_states = states(sys)
114112
alias_eqs = alias_vars .~ last.(subs)
115-
#alias_eqs = topsort_equations(alias_eqs, sys_states)
116113

117114
newstates = setdiff(sys_states, alias_vars)
118-
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_eqs)
115+
ODESystem(neweqs, sys.iv, newstates, parameters(sys), observed=alias_eqs)
119116
end
120117

121118
"""

0 commit comments

Comments
 (0)