Skip to content

Commit a977762

Browse files
committed
Fixes & tests for codegen w/ constants.
1 parent 23b98a2 commit a977762

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

src/structural_transformation/codegen.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra
22

3-
using ModelingToolkit: isdifferenceeq, process_events
3+
using ModelingToolkit: isdifferenceeq, process_events, get_preprocess_constants
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

@@ -187,12 +187,15 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
187187

188188
fname = gensym("fun")
189189
# f is the function to find roots on
190+
funex = isscalar ? rhss[1] : MakeArray(rhss, SVector)
191+
@show funex
192+
pre = get_preprocess_constants(funex)
190193
f = Func([DestructuredArgs(vars, inbounds = !checkbounds)
191194
DestructuredArgs(params, inbounds = !checkbounds)],
192195
[],
193-
Let(needed_assignments[inner_idxs],
194-
isscalar ? rhss[1] : MakeArray(rhss, SVector),
195-
false)) |> SymbolicUtils.Code.toexpr
196+
pre(Let(needed_assignments[inner_idxs],
197+
funex,
198+
false))) |> SymbolicUtils.Code.toexpr
196199

197200
# solver call contains code to call the root-finding solver on the function f
198201
solver_call = LiteralExpr(quote
@@ -294,15 +297,17 @@ function build_torn_function(sys;
294297
syms = map(Symbol, states)
295298

296299
pre = get_postprocess_fbody(sys)
300+
cpre = get_preprocess_constants(rhss)
301+
pre2 = x -> pre(cpre(x))
297302

298303
expr = SymbolicUtils.Code.toexpr(Func([out
299304
DestructuredArgs(states,
300-
inbounds = !checkbounds)
305+
inbounds = !checkbounds)
301306
DestructuredArgs(parameters(sys),
302-
inbounds = !checkbounds)
307+
inbounds = !checkbounds)
303308
independent_variables(sys)],
304309
[],
305-
pre(Let([torn_expr;
310+
pre2(Let([torn_expr;
306311
assignments[is_not_prepended_assignment]],
307312
funbody,
308313
false))),
@@ -469,12 +474,13 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
469474
push!(subs, sym obs[eqidx].rhs)
470475
end
471476
pre = get_postprocess_fbody(sys)
472-
477+
cpre = get_preprocess_constants([obs[1:maxidx]; isscalar ? ts[1] : MakeArray(ts, output_type) ])
478+
pre2 = x -> pre(cpre(x))
473479
ex = Code.toexpr(Func([DestructuredArgs(solver_states, inbounds = !checkbounds)
474480
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
475481
independent_variables(sys)],
476482
[],
477-
pre(Let([collect(Iterators.flatten(solves))
483+
pre2(Let([collect(Iterators.flatten(solves))
478484
assignments[is_not_prepended_assignment]
479485
map(eq -> eq.lhs eq.rhs, obs[1:maxidx])
480486
subs],

test/structural_transformation/tearing.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ using UnPack
1010
### Nonlinear system
1111
###
1212
@parameters t
13+
@constants h = 1
1314
@variables u1(t) u2(t) u3(t) u4(t) u5(t)
1415
eqs = [
15-
0 ~ u1 - sin(u5),
16+
0 ~ u1 - sin(u5) * h,
1617
0 ~ u2 - cos(u1),
1718
0 ~ u3 - hypot(u1, u2),
1819
0 ~ u4 - hypot(u2, u3),
@@ -147,13 +148,13 @@ using ModelingToolkit, OrdinaryDiffEq, BenchmarkTools
147148
@parameters t p
148149
@variables x(t) y(t) z(t)
149150
D = Differential(t)
150-
eqs = [D(x) ~ z
151+
eqs = [D(x) ~ z * h
151152
0 ~ x - y
152153
0 ~ sin(z) + y - p * t]
153154
@named daesys = ODESystem(eqs, t)
154155
newdaesys = tearing(daesys)
155-
@test equations(newdaesys) == [D(x) ~ z; 0 ~ y + sin(z) - p * t]
156-
@test equations(tearing_substitution(newdaesys)) == [D(x) ~ z; 0 ~ x + sin(z) - p * t]
156+
@test equations(newdaesys) == [D(x) ~ h * z; 0 ~ y + sin(z) - p * t]
157+
@test equations(tearing_substitution(newdaesys)) == [D(x) ~ h * z; 0 ~ x + sin(z) - p * t]
157158
@test isequal(states(newdaesys), [x, z])
158159
prob = ODAEProblem(newdaesys, [x => 1.0], (0, 1.0), [p => 0.2])
159160
du = [0.0];

0 commit comments

Comments
 (0)