|
1 | 1 | using LinearAlgebra
|
2 | 2 |
|
3 |
| -using ModelingToolkit: isdifferenceeq, process_events |
| 3 | +using ModelingToolkit: isdifferenceeq, process_events, get_preprocess_constants |
4 | 4 |
|
5 | 5 | const MAX_INLINE_NLSOLVE_SIZE = 8
|
6 | 6 |
|
@@ -187,12 +187,15 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
|
187 | 187 |
|
188 | 188 | fname = gensym("fun")
|
189 | 189 | # f is the function to find roots on
|
| 190 | + funex = isscalar ? rhss[1] : MakeArray(rhss, SVector) |
| 191 | + @show funex |
| 192 | + pre = get_preprocess_constants(funex) |
190 | 193 | f = Func([DestructuredArgs(vars, inbounds = !checkbounds)
|
191 | 194 | DestructuredArgs(params, inbounds = !checkbounds)],
|
192 | 195 | [],
|
193 |
| - Let(needed_assignments[inner_idxs], |
194 |
| - isscalar ? rhss[1] : MakeArray(rhss, SVector), |
195 |
| - false)) |> SymbolicUtils.Code.toexpr |
| 196 | + pre(Let(needed_assignments[inner_idxs], |
| 197 | + funex, |
| 198 | + false))) |> SymbolicUtils.Code.toexpr |
196 | 199 |
|
197 | 200 | # solver call contains code to call the root-finding solver on the function f
|
198 | 201 | solver_call = LiteralExpr(quote
|
@@ -294,15 +297,17 @@ function build_torn_function(sys;
|
294 | 297 | syms = map(Symbol, states)
|
295 | 298 |
|
296 | 299 | pre = get_postprocess_fbody(sys)
|
| 300 | + cpre = get_preprocess_constants(rhss) |
| 301 | + pre2 = x -> pre(cpre(x)) |
297 | 302 |
|
298 | 303 | expr = SymbolicUtils.Code.toexpr(Func([out
|
299 | 304 | DestructuredArgs(states,
|
300 |
| - inbounds = !checkbounds) |
| 305 | + inbounds = !checkbounds) |
301 | 306 | DestructuredArgs(parameters(sys),
|
302 |
| - inbounds = !checkbounds) |
| 307 | + inbounds = !checkbounds) |
303 | 308 | independent_variables(sys)],
|
304 | 309 | [],
|
305 |
| - pre(Let([torn_expr; |
| 310 | + pre2(Let([torn_expr; |
306 | 311 | assignments[is_not_prepended_assignment]],
|
307 | 312 | funbody,
|
308 | 313 | false))),
|
@@ -469,12 +474,13 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
|
469 | 474 | push!(subs, sym ← obs[eqidx].rhs)
|
470 | 475 | end
|
471 | 476 | pre = get_postprocess_fbody(sys)
|
472 |
| - |
| 477 | + cpre = get_preprocess_constants([obs[1:maxidx]; isscalar ? ts[1] : MakeArray(ts, output_type) ]) |
| 478 | + pre2 = x -> pre(cpre(x)) |
473 | 479 | ex = Code.toexpr(Func([DestructuredArgs(solver_states, inbounds = !checkbounds)
|
474 | 480 | DestructuredArgs(parameters(sys), inbounds = !checkbounds)
|
475 | 481 | independent_variables(sys)],
|
476 | 482 | [],
|
477 |
| - pre(Let([collect(Iterators.flatten(solves)) |
| 483 | + pre2(Let([collect(Iterators.flatten(solves)) |
478 | 484 | assignments[is_not_prepended_assignment]
|
479 | 485 | map(eq -> eq.lhs ← eq.rhs, obs[1:maxidx])
|
480 | 486 | subs],
|
|
0 commit comments