Skip to content

Commit e24a388

Browse files
committed
Fix codegen bugs and update tests
1 parent 542afe6 commit e24a388

File tree

6 files changed

+19
-45
lines changed

6 files changed

+19
-45
lines changed

src/structural_transformation/codegen.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,13 @@ function build_observed_function(
419419
if !isempty(subset)
420420
eqs = equations(sys)
421421

422-
torn_eqs = map(i->map(v->eqs[var_eq_matching[v]], var_sccs[i]), subset)
423-
torn_vars = map(i->map(v->fullvars[v], var_sccs[i]), subset)
422+
nested_torn_vars_idxs = []
423+
for iscc in subset
424+
torn_vars_idxs = Int[var for var in var_sccs[iscc] if var_eq_matching[var] !== unassigned]
425+
isempty(torn_vars_idxs) || push!(nested_torn_vars_idxs, torn_vars_idxs)
426+
end
427+
torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs] for idxs in nested_torn_vars_idxs]
428+
torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs]
424429
u0map = defaults(sys)
425430
assignments = copy(assignments)
426431
solves = map(zip(torn_eqs, torn_vars)) do (eqs, vars)

src/structural_transformation/symbolics_tearing.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify=false
153153
# convert it into the mass matrix form.
154154
# We cannot solve the differential variable like D(x)
155155
if isdiffvar(iv)
156-
push!(diffeqs, solve_equation(eqs[ieq], fullvars[iv], simplify))
156+
push!(diffeqs, solve_equation(neweqs[ieq], fullvars[iv], simplify))
157157
continue
158158
end
159159
push!(solved_equations, ieq); push!(solved_variables, iv)
@@ -174,25 +174,17 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify=false
174174
length(dterms) == 0 && return 0 ~ rhs
175175
new_rhs = rhs
176176
new_lhs = 0
177-
nnegative = 0
178177
for iv in dterms
179178
var = fullvars[iv]
179+
# 0 ~ a * D(x) + b
180+
# D(x) ~ -b/a
180181
a, b, islinear = linear_expansion(new_rhs, var)
181182
au = unwrap(a)
182-
if !islinear || (au isa Symbolic) || isinput(var) || !(au isa Number)
183+
if !islinear
183184
return 0 ~ rhs
184185
end
185-
if -au < 0
186-
nnegative += 1
187-
end
188-
new_lhs -= a*var
189-
new_rhs = b
190-
end
191-
# If most of the terms are negative, just multiply through by -1
192-
# to make the equations looks slightly nicer.
193-
if nnegative > div(length(dterms), 2)
194-
new_lhs = -new_lhs
195-
new_rhs = -new_rhs
186+
new_lhs += var
187+
new_rhs = -b/a
196188
end
197189
return new_lhs ~ new_rhs
198190
else # a number

test/reduction.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,11 @@ eq = [
256256
sys = structural_simplify(sys0)
257257
@test length(equations(sys)) == 1
258258
eq = equations(tearing_substitution(sys))[1]
259-
@test isequal(eq.lhs, 0)
259+
@test isequal(eq.lhs, D(v25))
260260
dv25 = ModelingToolkit.value(ModelingToolkit.derivative(eq.rhs, v25))
261-
ddv25 = ModelingToolkit.value(ModelingToolkit.derivative(eq.lhs, D(v25)))
262261
dt = ModelingToolkit.value(ModelingToolkit.derivative(eq.rhs, sin(10t)))
263-
@test dv25 -0.3
264-
@test ddv25 == 0.005
265-
@test dt == 0.1
262+
@test dv25 -60
263+
@test dt 20
266264

267265
# Don't reduce inputs
268266
@parameters t σ ρ β

test/structural_transformation/index_reduction.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,6 @@ state = TearingState(pendulum)
3535
@unpack graph, var_to_diff = state.structure
3636
@test StructuralTransformations.maximal_matching(graph, eq->true, v->var_to_diff[v] === nothing) == map(x -> x == 0 ? StructuralTransformations.unassigned : x, [1, 2, 3, 4, 0, 0, 0, 0, 0])
3737

38-
sys, var_eq_matching, eq_to_diff = StructuralTransformations.pantelides!(pendulum)
39-
state = TearingState(sys)
40-
@unpack graph, var_to_diff = state.structure
41-
@test graph.fadjlist == [[1, 7], [2, 8], [3, 5, 9], [4, 6, 9], [5, 6], [1, 2, 5, 6], [1, 3, 7, 10], [2, 4, 8, 11], [1, 2, 5, 6, 10, 11]]
42-
let N=nothing;
43-
@test var_to_diff == [10, 11, N, N, 1, 2, 3, 4, N, N, N];
44-
#1: D(x) ~ w
45-
#2: D(y) ~ z
46-
#3: D(w) ~ T*x
47-
#4: D(z) ~ T*y - g
48-
#5: 0 ~ x^2 + y^2 - L^2
49-
# ----
50-
#6: D(eq:5) -> 0 ~ 2xx'+ 2yy'
51-
#7: D(eq:1) -> D(D(x)) ~ D(w) -> D(xˍt) ~ D(w) -> D(xˍt) ~ T*x
52-
#8: D(eq:2) -> D(D(y)) ~ D(z) -> D(y_t) ~ T*y - g
53-
#9: D(eq:6) -> 0 ~ 2xx'' + 2x'x' + 2yy'' + 2y'y'
54-
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
55-
@test eq_to_diff == [7, 8, N, N, 6, 9, N, N, N]
56-
end
57-
5838
using ModelingToolkit
5939
@parameters t L g
6040
@variables x(t) y(t) w(t) z(t) T(t) xˍt(t) yˍt(t)

test/structural_transformation/tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ graph2vars(graph) = map(is->Set(map(i->int2var[i], is)), graph.fadjlist)
6161
state = TearingState(tearing(sys))
6262
let sss = state.structure
6363
@unpack graph = sss
64-
@test graph2vars(graph) == [Set([u5])]
64+
@test graph2vars(graph) == [Set([u1, u2, u5])]
6565
end
6666

6767
# Before:

test/structural_transformation/utils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@ pendulum = ODESystem(eqs, t, [x, y, w, z, T], [L, g], name=:pendulum)
1919
state = TearingState(pendulum)
2020
StructuralTransformations.find_solvables!(state)
2121
sss = state.structure
22-
@unpack graph, solvable_graph, fullvars, var_to_diff = sss
23-
@test isequal(fullvars, [D(x), D(y), D(w), D(z), x, y, w, z, T])
22+
@unpack graph, solvable_graph, var_to_diff = sss
2423
@test graph.fadjlist == [[1, 7], [2, 8], [3, 5, 9], [4, 6, 9], [5, 6]]
25-
@test graph.badjlist == 9 == length(fullvars)
24+
@test graph.badjlist == 9
2625
@test ne(graph) == nnz(incidence_matrix(graph)) == 12
2726
@test nv(solvable_graph) == 9 + 5
2827
let N = nothing

0 commit comments

Comments
 (0)