Skip to content

Commit f29824e

Browse files
committed
Format
1 parent fc2a309 commit f29824e

File tree

7 files changed

+69
-57
lines changed

7 files changed

+69
-57
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using SymbolicUtils: maketerm, iscall
1111

1212
using ModelingToolkit
1313
using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential,
14-
unknowns, equations, vars, Symbolic, diff2term_with_unit, shift2term_with_unit, value,
14+
unknowns, equations, vars, Symbolic, diff2term_with_unit,
15+
shift2term_with_unit, value,
1516
operation, arguments, Sym, Term, simplify, symbolic_linear_solve,
1617
isdiffeq, isdifferential, isirreducible,
1718
empty_substitutions, get_substitutions,
@@ -22,7 +23,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2223
get_postprocess_fbody, vars!,
2324
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2425
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25-
filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
26+
filter_kwargs, lower_varname_with_unit,
27+
lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
2628
get_fullvars, has_equations, observed,
2729
Schedule, schedule
2830

src/structural_transformation/symbolics_tearing.jl

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ called dummy derivatives.
248248
State selection is done. All non-differentiated variables are algebraic
249249
variables, and all variables that appear differentiated are differential variables.
250250
"""
251-
function substitute_derivatives_algevars!(ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing)
251+
function substitute_derivatives_algevars!(
252+
ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing)
252253
@unpack fullvars, sys, structure = ts
253254
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
254255
diff_to_var = invview(var_to_diff)
@@ -288,7 +289,7 @@ end
288289
#=
289290
There are three cases where we want to generate new variables to convert
290291
the system into first order (semi-implicit) ODEs.
291-
292+
292293
1. To first order:
293294
Whenever higher order differentiated variable like `D(D(D(x)))` appears,
294295
we introduce new variables `x_t`, `x_tt`, and `x_ttt` and new equations
@@ -364,7 +365,8 @@ Effects on the system structure:
364365
- solvable_graph:
365366
- var_eq_matching: match D(x) to the added identity equation D(x) ~ x_t
366367
"""
367-
function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
368+
function generate_derivative_variables!(
369+
ts::TearingState, neweqs, var_eq_matching; mm = nothing, iv = nothing, D = nothing)
368370
@unpack fullvars, sys, structure = ts
369371
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
370372
eq_var_matching = invview(var_eq_matching)
@@ -395,7 +397,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
395397
dx = fullvars[dv]
396398
order, lv = var_order(dv, diff_to_var)
397399
x_t = is_discrete ? lower_shift_varname_with_unit(fullvars[dv], iv) :
398-
lower_varname_with_unit(fullvars[lv], iv, order)
400+
lower_varname_with_unit(fullvars[lv], iv, order)
399401

400402
# Add `x_t` to the graph
401403
v_t = add_dd_variable!(structure, fullvars, x_t, dv)
@@ -405,7 +407,7 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
405407
# Update matching
406408
push!(var_eq_matching, unassigned)
407409
var_eq_matching[dv] = unassigned
408-
eq_var_matching[dummy_eq] = dv
410+
eq_var_matching[dummy_eq] = dv
409411
end
410412
end
411413

@@ -428,7 +430,7 @@ function find_duplicate_dd(dv, solvable_graph, diff_to_var, linear_eqs, mm)
428430
return eq, v_t
429431
end
430432
end
431-
return nothing
433+
return nothing
432434
end
433435

434436
"""
@@ -492,8 +494,9 @@ Order the new equations and variables such that the differential equations
492494
and variables come first. Return the new equations, the solved equations,
493495
the new orderings, and the number of solved variables and equations.
494496
"""
495-
function generate_system_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, iv = nothing, D = nothing)
496-
@unpack fullvars, sys, structure = state
497+
function generate_system_equations!(state::TearingState, neweqs, var_eq_matching;
498+
simplify = false, iv = nothing, D = nothing)
499+
@unpack fullvars, sys, structure = state
497500
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
498501
eq_var_matching = invview(var_eq_matching)
499502
diff_to_var = invview(var_to_diff)
@@ -502,11 +505,12 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
502505
if is_only_discrete(structure)
503506
for (i, v) in enumerate(fullvars)
504507
op = operation(v)
505-
op isa Shift && (op.steps < 0) && begin
506-
lowered = lower_shift_varname_with_unit(v, iv)
507-
total_sub[v] = lowered
508-
fullvars[i] = lowered
509-
end
508+
op isa Shift && (op.steps < 0) &&
509+
begin
510+
lowered = lower_shift_varname_with_unit(v, iv)
511+
total_sub[v] = lowered
512+
fullvars[i] = lowered
513+
end
510514
end
511515
end
512516

@@ -581,10 +585,11 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
581585
end
582586
solved_vars_set = BitSet(solved_vars)
583587
var_ordering = [diff_vars;
584-
setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
585-
solved_vars_set)]
588+
setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
589+
solved_vars_set)]
586590

587-
return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
591+
return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars),
592+
length(solved_vars_set)
588593
end
589594

590595
"""
@@ -648,7 +653,8 @@ Eliminate the solved variables and equations from the graph and permute the
648653
graph's vertices to account for the new variable/equation ordering.
649654
"""
650655
# TODO: BLT sorting
651-
function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nsolved_eq, nsolved_var)
656+
function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering,
657+
var_ordering, nsolved_eq, nsolved_var)
652658
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
653659

654660
eqsperm = zeros(Int, nsrcs(graph))
@@ -692,7 +698,8 @@ end
692698
"""
693699
Update the system equations, unknowns, and observables after simplification.
694700
"""
695-
function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
701+
function update_simplified_system!(
702+
state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
696703
cse_hack = true, array_hack = true)
697704
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
698705
diff_to_var = invview(var_to_diff)
@@ -732,7 +739,6 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
732739
sys = schedule(sys)
733740
end
734741

735-
736742
"""
737743
Give the order of the variable indexed by dv.
738744
"""
@@ -790,12 +796,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
790796

791797
generate_derivative_variables!(state, neweqs, var_eq_matching; mm, iv, D)
792798

793-
neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var =
794-
generate_system_equations!(state, neweqs, var_eq_matching; simplify, iv, D)
799+
neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = generate_system_equations!(
800+
state, neweqs, var_eq_matching; simplify, iv, D)
795801

796-
state = reorder_vars!(state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
802+
state = reorder_vars!(
803+
state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
797804

798-
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; cse_hack, array_hack)
805+
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
806+
extra_unknowns; cse_hack, array_hack)
799807

800808
@set! state.sys = sys
801809
@set! sys.tearing_state = state

src/structural_transformation/utils.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,15 +477,17 @@ function shift2term(var)
477477
backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps
478478

479479
num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
480-
ds = join([Char(0x209c), Char(0x208b), num])
480+
ds = join([Char(0x209c), Char(0x208b), num])
481481
# Char(0x209c) = ₜ
482482
# Char(0x208b) = ₋ (subscripted minus)
483483

484484
O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg
485485
oldop = operation(O)
486-
newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : Symbol(string(nameof(oldop)))
486+
newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) :
487+
Symbol(string(nameof(oldop)))
487488

488-
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
489+
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname),
490+
Symbolics.children(O), Symbolics.metadata(O))
489491
newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
490492
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
491493
newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift)

src/systems/discrete_system/discrete_system.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,15 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
269269
for k in collect(keys(u0map))
270270
v = u0map[k]
271271
if !((op = operation(k)) isa Shift)
272-
isnothing(getunshifted(k)) && error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
273-
272+
isnothing(getunshifted(k)) &&
273+
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
274+
274275
updated[Shift(iv, 1)(k)] = v
275276
elseif op.steps > 0
276277
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")
277278
else
278279
updated[Shift(iv, op.steps + 1)(only(arguments(k)))] = v
279280
end
280-
281281
end
282282
for var in unknowns(sys)
283283
op = operation(var)

src/systems/systemstructure.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,9 @@ function shift_discrete_system(ts::TearingState)
473473
end
474474
iv = get_iv(sys)
475475

476-
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
477-
for k in discvars
478-
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
476+
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
477+
for k in discvars
478+
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
479479

480480
for i in eachindex(fullvars)
481481
fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(

src/variables.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function default_toterm(x)
138138
if iscall(x) && (op = operation(x)) isa Operator
139139
if !(op isa Differential)
140140
if op isa Shift && op.steps < 0
141-
return shift2term(x)
141+
return shift2term(x)
142142
end
143143
x = normalize_to_differential(op)(arguments(x)...)
144144
end

test/discrete_system.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,6 @@ k = ShiftIndex(t)
257257
@named sys = DiscreteSystem([x ~ x^2 + y^2, y ~ x(k - 1) + y(k - 1)], t)
258258
@test_throws ["algebraic equations", "not yet supported"] structural_simplify(sys)
259259

260-
261260
@testset "Passing `nothing` to `u0`" begin
262261
@variables x(t) = 1
263262
k = ShiftIndex()
@@ -273,11 +272,11 @@ end
273272
prob = DiscreteProblem(de, [], (0, 10))
274273
@test prob[x] == 2.0
275274
@test prob[x(k - 1)] == 1.0
276-
275+
277276
# must provide initial conditions for history
278277
@test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10))
279-
@test_throws ErrorException DiscreteProblem(de, [x(k+1) => 2.], (0, 10))
280-
278+
@test_throws ErrorException DiscreteProblem(de, [x(k + 1) => 2.0], (0, 10))
279+
281280
# initial values only affect _that timestep_, not the entire history
282281
prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
283282
@test prob[x] == 3.0
@@ -286,34 +285,35 @@ end
286285
@test prob[xₜ₋₁] == 2.0
287286

288287
# Test initial assignment with lowered variable
289-
prob = DiscreteProblem(de, [xₜ₋₁(k-1) => 4.0], (0, 10))
290-
@test prob[x(k-1)] == prob[xₜ₋₁] == 1.0
291-
@test prob[x] == 5.
288+
prob = DiscreteProblem(de, [xₜ₋₁(k - 1) => 4.0], (0, 10))
289+
@test prob[x(k - 1)] == prob[xₜ₋₁] == 1.0
290+
@test prob[x] == 5.0
292291

293292
# Test missing initial throws error
294293
@variables x(t)
295-
@mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
296-
@test_throws ErrorException prob = DiscreteProblem(de, [x(k-3) => 2.], (0, 10))
297-
@test_throws ErrorException prob = DiscreteProblem(de, [x(k-3) => 2., x(k-1) => 3.], (0, 10))
294+
@mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t)
295+
@test_throws ErrorException prob=DiscreteProblem(de, [x(k - 3) => 2.0], (0, 10))
296+
@test_throws ErrorException prob=DiscreteProblem(
297+
de, [x(k - 3) => 2.0, x(k - 1) => 3.0], (0, 10))
298298

299299
# Test non-assigned initials are given default value
300-
@variables x(t) = 2.
301-
@mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
302-
prob = DiscreteProblem(de, [x(k-3) => 12.], (0, 10))
300+
@variables x(t) = 2.0
301+
@mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t)
302+
prob = DiscreteProblem(de, [x(k - 3) => 12.0], (0, 10))
303303
@test prob[x] == 26.0
304-
@test prob[x(k-1)] == 2.0
305-
@test prob[x(k-2)] == 2.0
304+
@test prob[x(k - 1)] == 2.0
305+
@test prob[x(k - 2)] == 2.0
306306

307307
# Elaborate test
308308
@variables xₜ₋₂(t) zₜ₋₁(t) z(t)
309-
eqs = [x ~ x(k-1) + z(k-2),
310-
z ~ x(k-2) * x(k-3) - z(k-1)^2]
309+
eqs = [x ~ x(k - 1) + z(k - 2),
310+
z ~ x(k - 2) * x(k - 3) - z(k - 1)^2]
311311
@mtkbuild de = DiscreteSystem(eqs, t)
312-
u0 = [x(k-1) => 3,
313-
xₜ₋₂(k-1) => 4,
314-
x(k-2) => 1,
315-
z(k-1) => 5,
316-
zₜ₋₁(k-1) => 12]
312+
u0 = [x(k - 1) => 3,
313+
xₜ₋₂(k - 1) => 4,
314+
x(k - 2) => 1,
315+
z(k - 1) => 5,
316+
zₜ₋₁(k - 1) => 12]
317317
prob = DiscreteProblem(de, u0, (0, 10))
318318
@test prob[x] == 15
319319
@test prob[z] == -21

0 commit comments

Comments
 (0)