Skip to content

Commit 081687d

Browse files
committed
Fix tests
1 parent ad53183 commit 081687d

File tree

9 files changed

+65
-54
lines changed

9 files changed

+65
-54
lines changed

src/structural_transformation/codegen.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ function build_torn_function(
249249
states_idxs = collect(diffvars_range(s))
250250
mass_matrix_diag = ones(length(states_idxs))
251251

252-
assignments, deps, bf_states = tearing_assignments(sys)
252+
assignments, deps, sol_states = tearing_assignments(sys)
253253
invdeps = map(_->BitSet(), deps)
254254
for (i, d) in enumerate(deps)
255255
for a in d
@@ -310,16 +310,16 @@ function build_torn_function(
310310
funbody
311311
))
312312
),
313-
bf_states
313+
sol_states
314314
)
315315
if expression
316316
expr, states
317317
else
318-
observedfun = let sys=sys, dict=Dict(), assignments=assignments, deps=(deps, invdeps), bf_states=bf_states, var2assignment=var2assignment
318+
observedfun = let sys=sys, dict=Dict(), assignments=assignments, deps=(deps, invdeps), sol_states=sol_states, var2assignment=var2assignment
319319
function generated_observed(obsvar, u, p, t)
320320
obs = get!(dict, value(obsvar)) do
321321
build_observed_function(sys, obsvar, var_eq_matching, var_sccs,
322-
assignments, deps, bf_states, var2assignment,
322+
assignments, deps, sol_states, var2assignment,
323323
checkbounds=checkbounds,
324324
)
325325
end
@@ -358,7 +358,7 @@ function build_observed_function(
358358
sys, ts, var_eq_matching, var_sccs,
359359
assignments,
360360
deps,
361-
bf_states,
361+
sol_states,
362362
var2assignment;
363363
expression=false,
364364
output_type=Array,
@@ -445,7 +445,7 @@ function build_observed_function(
445445
],
446446
isscalar ? ts[1] : MakeArray(ts, output_type)
447447
))
448-
), bf_states)
448+
), sol_states)
449449

450450
expression ? ex : @RuntimeGeneratedFunction(ex)
451451
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ function tearing_assignments(sys::AbstractSystem)
6262
if empty_substitutions(sys)
6363
assignments = []
6464
deps = Int[]
65-
bf_states = Code.LazyState()
65+
sol_states = Code.LazyState()
6666
else
6767
subs, deps = get_substitutions(sys)
6868
assignments = [Assignment(eq.lhs, eq.rhs) for eq in subs]
69-
bf_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
69+
sol_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
7070
end
71-
return assignments, deps, bf_states
71+
return assignments, deps, sol_states
7272
end
7373

7474
function solve_equation(eq, var, simplify)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,24 +99,12 @@ function generate_function(
9999
p = map(x->time_varying_as_func(value(x), sys), ps)
100100
t = get_iv(sys)
101101

102-
if empty_substitutions(sys)
103-
bf_states = Code.LazyState()
104-
pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
105-
else
106-
subs, = get_substitutions(sys)
107-
bf_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
108-
if has_difference
109-
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], ex)
110-
else
111-
process = get_postprocess_fbody(sys)
112-
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], process(ex))
113-
end
114-
end
102+
pre, sol_states = get_substitutions_and_solved_states(sys, no_postprocess = has_difference)
115103

116104
if implicit_dae
117-
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre, states=bf_states, kwargs...)
105+
build_function(rhss, ddvs, u, p, t; postprocess_fbody=pre, states=sol_states, kwargs...)
118106
else
119-
build_function(rhss, u, p, t; postprocess_fbody=pre, states=bf_states, kwargs...)
107+
build_function(rhss, u, p, t; postprocess_fbody=pre, states=sol_states, kwargs...)
120108
end
121109
end
122110

src/systems/discrete_system/discrete_system.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ eqs = [D(x) ~ σ*(y-x),
1919
D(y) ~ x*(ρ-z)-y,
2020
D(z) ~ x*y - β*z]
2121
22-
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]) # or
22+
@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]) # or
2323
@named de = DiscreteSystem(eqs)
2424
```
2525
"""
@@ -59,13 +59,18 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
5959
type: type of the system
6060
"""
6161
connector_type::Any
62-
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, connector_type; checks::Bool = true)
62+
"""
63+
substitutions: substitutions generated by tearing.
64+
"""
65+
substitutions::Any
66+
67+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, connector_type, substitutions=nothing; checks::Bool = true)
6368
if checks
6469
check_variables(dvs, iv)
6570
check_parameters(ps, iv)
6671
all_dimensionless([dvs;ps;iv;ctrls]) || check_units(discreteEqs)
6772
end
68-
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, connector_type)
73+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, connector_type, substitutions)
6974
end
7075
end
7176

@@ -103,7 +108,7 @@ function DiscreteSystem(
103108
process_variables!(var_to_name, defaults, dvs′)
104109
process_variables!(var_to_name, defaults, ps′)
105110
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
106-
111+
107112
sysnames = nameof.(systems)
108113
if length(unique(sysnames)) != length(sysnames)
109114
throw(ArgumentError("System names must be unique."))
@@ -187,7 +192,7 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
187192
end
188193

189194
u0 = varmap_to_vars(u0map,dvs; defaults=u0defs)
190-
195+
191196
rhss = [eq.rhs for eq in eqs]
192197
u = dvs
193198
p = varmap_to_vars(parammap,ps; defaults=pdefs)
@@ -226,15 +231,15 @@ function linearize_eqs(sys, eqs=get_eqs(sys); return_max_delay=false)
226231
end
227232

228233
all(length.(unique.(values(state_ops))) .<= 1) || error("Each state should be used with single difference operator.")
229-
234+
230235
dts_gcd = Dict()
231236
for v in keys(dts)
232237
dts_gcd[v] = (length(dts[v]) > 0) ? first(dts[v]) : nothing
233238
end
234239

235240
lin_eqs = [
236241
v(get_iv(sys) - (t)) ~ v(get_iv(sys) - (t-dts_gcd[v]))
237-
for v in unique_states if max_delay[v] > 0 && dts_gcd[v]!==nothing for t in collect(max_delay[v]:(-dts_gcd[v]):0)[1:end-1]
242+
for v in unique_states if max_delay[v] > 0 && dts_gcd[v]!==nothing for t in collect(max_delay[v]:(-dts_gcd[v]):0)[1:end-1]
238243
]
239244
eqs = vcat(eqs, lin_eqs)
240245
end
@@ -256,12 +261,12 @@ function generate_function(
256261
)
257262
eqs = equations(sys)
258263
foreach(check_difference_variables, eqs)
259-
# substitute x(t) by just x
260264
rhss = [eq.rhs for eq in eqs]
261265

262266
u = map(x->time_varying_as_func(value(x), sys), dvs)
263267
p = map(x->time_varying_as_func(value(x), sys), ps)
264268
t = get_iv(sys)
265-
266-
build_function(rhss, u, p, t; kwargs...)
269+
270+
pre, sol_states = get_substitutions_and_solved_states(sys)
271+
build_function(rhss, u, p, t; postprocess_fbody=pre, states=sol_states, kwargs...)
267272
end

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ $(FIELDS)
1313
```julia
1414
using ModelingToolkit
1515
16-
@parameters β γ
16+
@parameters β γ
1717
@variables t S(t) I(t) R(t)
1818
rate₁ = β*S*I
1919
affect₁ = [S ~ S - 1, I ~ I + 1]
@@ -301,7 +301,7 @@ function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
301301
vtoj = nothing; jtov = nothing; jtoj = nothing
302302
end
303303

304-
JumpProblem(prob, aggregator, jset; dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov,
304+
JumpProblem(prob, aggregator, jset; dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov,
305305
scale_rates=false, nocopy=true, kwargs...)
306306
end
307307

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,16 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
5454
type: type of the system
5555
"""
5656
connector_type::Any
57-
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connector_type; checks::Bool = true)
57+
"""
58+
substitutions: substitutions generated by tearing.
59+
"""
60+
substitutions::Any
61+
62+
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connector_type, substitutions=nothing; checks::Bool = true)
5863
if checks
5964
all_dimensionless([states;ps]) ||check_units(eqs)
6065
end
61-
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connector_type)
66+
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connector_type, substitutions)
6267
end
6368
end
6469

@@ -123,19 +128,14 @@ end
123128
function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys);
124129
sparse = false, simplify=false, kwargs...)
125130
jac = calculate_jacobian(sys,sparse=sparse, simplify=simplify)
126-
return build_function(jac, vs, ps;
127-
conv = AbstractSysToExpr(sys), kwargs...)
131+
return build_function(jac, vs, ps; kwargs...)
128132
end
129133

130134
function generate_function(sys::NonlinearSystem, dvs = states(sys), ps = parameters(sys); kwargs...)
131-
#obsvars = map(eq->eq.lhs, observed(sys))
132-
#fulldvs = [dvs; obsvars]
133-
134135
rhss = [deq.rhs for deq equations(sys)]
135-
#rhss = Let(obss, rhss)
136+
pre, sol_states = get_substitutions_and_solved_states(sys)
136137

137-
return build_function(rhss, value.(dvs), value.(ps);
138-
conv = AbstractSysToExpr(sys), kwargs...)
138+
return build_function(rhss, value.(dvs), value.(ps); postprocess_fbody=pre, states=sol_states, kwargs...)
139139
end
140140

141141
jacobian_sparsity(sys::NonlinearSystem) =

src/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,20 @@ function empty_substitutions(sys)
432432
subs = get_substitutions(sys)
433433
isnothing(subs) || isempty(last(subs))
434434
end
435+
436+
function get_substitutions_and_solved_states(sys; no_postprocess=false)
437+
if empty_substitutions(sys)
438+
sol_states = Code.LazyState()
439+
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)
440+
else
441+
subs, = get_substitutions(sys)
442+
sol_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
443+
if no_postprocess
444+
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], ex)
445+
else
446+
process = get_postprocess_fbody(sys)
447+
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], process(ex))
448+
end
449+
end
450+
return pre, sol_states
451+
end

test/reduction.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ io = IOBuffer(); show(io, MIME("text/plain"), lorenz1_aliased); str = String(tak
4444
@test all(s->occursin(s, str), ["lorenz1", "States (2)", "Parameters (3)"])
4545
reduced_eqs = [
4646
D(x) ~ σ*(y - x)
47-
D(y) ~ β + x*- (x - y)) - y
47+
D(y) ~ β + x*- z) - y
4848
]
4949
test_equal.(equations(lorenz1_aliased), reduced_eqs)
5050
@test isempty(setdiff(states(lorenz1_aliased), [x, y, z]))
@@ -84,10 +84,10 @@ __x = x
8484
# Reduced Flattened System
8585

8686
reduced_system = structural_simplify(connected)
87-
reduced_system2 = structural_simplify(structural_simplify(structural_simplify(connected)))
87+
reduced_system2 = structural_simplify(tearing_substitution(structural_simplify(tearing_substitution(structural_simplify(connected)))))
8888

8989
@test isempty(setdiff(states(reduced_system), states(reduced_system2)))
90-
@test isequal(equations(reduced_system), equations(reduced_system2))
90+
@test isequal(equations(tearing_substitution(reduced_system)), equations(reduced_system2))
9191
@test isequal(observed(reduced_system), observed(reduced_system2))
9292
@test setdiff(states(reduced_system), [
9393
s
@@ -155,7 +155,7 @@ let
155155
reduced_sys = structural_simplify(connected)
156156
ref_eqs = [
157157
D(ol.x) ~ ol.a*ol.x + ol.b*ol.u
158-
0 ~ pc.k_P*(ol.c*ol.x + ol.d*ol.u) - ol.u
158+
0 ~ pc.k_P*ol.y - ol.u
159159
]
160160
@test ref_eqs == equations(reduced_sys)
161161
end
@@ -254,7 +254,7 @@ eq = [
254254
@named sys0 = ODESystem(eq, t)
255255
sys = structural_simplify(sys0)
256256
@test length(equations(sys)) == 1
257-
eq = equations(sys)[1]
257+
eq = equations(tearing_substitution(sys))[1]
258258
@test isequal(eq.lhs, 0)
259259
dv25 = ModelingToolkit.value(ModelingToolkit.derivative(eq.rhs, v25))
260260
ddv25 = ModelingToolkit.value(ModelingToolkit.derivative(eq.rhs, D(v25)))

test/structural_transformation/tearing.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ end
122122
# solve for
123123
# 0 = u5 - hypot(sin(u5), hypot(cos(sin(u5)), hypot(sin(u5), cos(sin(u5)))))
124124
tornsys = tearing(sys)
125-
@test isequal(equations(tornsys), [0 ~ u5 + (-1 * hypot(hypot(cos(sin(u5)), hypot(sin(u5), cos(sin(u5)))), sin(u5)))])
125+
@test isequal(equations(tornsys), [0 ~ u5 - hypot(u4, u1)])
126126
prob = NonlinearProblem(tornsys, ones(1))
127127
sol = solve(prob, NewtonRaphson())
128128
@test norm(prob.f(sol.u, sol.prob.p)) < 1e-10
@@ -147,7 +147,7 @@ let (mm, _, _) = ModelingToolkit.aag_bareiss(nlsys)
147147
end
148148

149149
newsys = tearing(nlsys)
150-
@test length(equations(newsys)) == 0
150+
@test length(equations(newsys)) <= 1
151151

152152
###
153153
### DAE system
@@ -163,7 +163,8 @@ eqs = [
163163
]
164164
@named daesys = ODESystem(eqs, t)
165165
newdaesys = tearing(daesys)
166-
@test equations(newdaesys) == [D(x) ~ z; 0 ~ x + sin(z) - p*t]
166+
@test equations(newdaesys) == [D(x) ~ z; 0 ~ y + sin(z) - p*t]
167+
@test equations(tearing_substitution(newdaesys)) == [D(x) ~ z; 0 ~ x + sin(z) - p*t]
167168
@test isequal(states(newdaesys), [x, z])
168169
prob = ODAEProblem(newdaesys, [x=>1.0], (0, 1.0), [p=>0.2])
169170
du = [0.0]; u = [1.0]; pr = 0.2; tt = 0.1

0 commit comments

Comments
 (0)