Skip to content

Commit 979d6be

Browse files
authored
Merge pull request #1680 from SciML/fb_simpler_control2
refactor simplification and `generate_control_function`
2 parents 84210d0 + 9626fee commit 979d6be

File tree

3 files changed

+48
-70
lines changed

3 files changed

+48
-70
lines changed

src/inputoutput.jl

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -170,38 +170,28 @@ The return values also include the remaining states and parameters, in the order
170170
# Example
171171
```
172172
using ModelingToolkit: generate_control_function, varmap_to_vars, defaults
173-
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=true)
173+
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=false)
174174
p = varmap_to_vars(defaults(sys), ps)
175175
x = varmap_to_vars(defaults(sys), dvs)
176176
t = 0
177177
f[1](x, inputs, p, t)
178178
```
179179
"""
180-
function generate_control_function(sys::AbstractODESystem;
180+
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys);
181181
implicit_dae = false,
182-
has_difference = false,
183-
simplify = true,
182+
simplify = false,
184183
kwargs...)
185-
ctrls = unbound_inputs(sys)
186-
if isempty(ctrls)
184+
if isempty(inputs)
187185
error("No unbound inputs were found in system.")
188186
end
189187

190-
# One can either connect unbound inputs to new parameters and allow structural_simplify, but then the unbound inputs appear as states :( .
191-
# One can also just remove them from the states and parameters for the purposes of code generation, but then structural_simplify fails :(
192-
# To have the best of both worlds, all unbound inputs must be converted to `@parameters` in which case structural_simplify handles them correctly :)
193-
sys = toparam(sys, ctrls)
194-
195-
if simplify
196-
sys = structural_simplify(sys)
197-
end
188+
sys, diff_idxs, alge_idxs = io_preprocessing(sys, inputs, []; simplify,
189+
check_bound = false, kwargs...)
198190

199191
dvs = states(sys)
200192
ps = parameters(sys)
201-
202-
dvs = setdiff(dvs, ctrls)
203-
ps = setdiff(ps, ctrls)
204-
inputs = map(x -> time_varying_as_func(value(x), sys), ctrls)
193+
ps = setdiff(ps, inputs)
194+
inputs = map(x -> time_varying_as_func(value(x), sys), inputs)
205195

206196
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
207197
check_operator_variables(eqs, Differential)
@@ -223,24 +213,10 @@ function generate_control_function(sys::AbstractODESystem;
223213
end
224214
pre, sol_states = get_substitutions_and_solved_states(sys)
225215
f = build_function(rhss, args...; postprocess_fbody = pre, states = sol_states,
226-
kwargs...)
216+
expression = Val{false}, kwargs...)
227217
f, dvs, ps
228218
end
229219

230-
"""
231-
toparam(sys, ctrls::AbstractVector)
232-
233-
Transform all instances of `@varibales` in `ctrls` appearing as states and in equations of `sys` with similarly named `@parameters`. This allows [`structural_simplify`](@ref)(sys) in the presence unbound inputs.
234-
"""
235-
function toparam(sys, ctrls::AbstractVector)
236-
eqs = equations(sys)
237-
subs = Dict(ctrls .=> toparam.(ctrls))
238-
eqs = map(eqs) do eq
239-
substitute(eq.lhs, subs) ~ substitute(eq.rhs, subs)
240-
end
241-
ODESystem(eqs, name = nameof(sys))
242-
end
243-
244220
function inputs_to_parameters!(state::TransformationState, check_bound = true)
245221
@unpack structure, fullvars, sys = state
246222
@unpack var_to_diff, graph, solvable_graph = structure

src/systems/abstractsystem.jl

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -948,10 +948,17 @@ function will be applied during the tearing process. It also takes kwargs
948948
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
949949
types during tearing.
950950
"""
951-
function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
951+
function structural_simplify(sys::AbstractSystem, args...; kwargs...)
952952
sys = expand_connections(sys)
953953
state = TearingState(sys)
954-
state, = inputs_to_parameters!(state)
954+
sys, input_idxs = _structural_simplify(sys, state, args...; kwargs...)
955+
sys
956+
end
957+
958+
function _structural_simplify(sys::AbstractSystem, state; simplify = false,
959+
check_bound = true,
960+
kwargs...)
961+
state, input_idxs = inputs_to_parameters!(state, check_bound)
955962
sys = alias_elimination!(state)
956963
state = TearingState(sys)
957964
check_consistency(state)
@@ -964,7 +971,31 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
964971
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
965972
@set! sys.observed = topsort_equations(observed(sys), fullstates)
966973
invalidate_cache!(sys)
967-
return sys
974+
return sys, input_idxs
975+
end
976+
977+
function io_preprocessing(sys::AbstractSystem, inputs,
978+
outputs; simplify = false, kwargs...)
979+
sys = expand_connections(sys)
980+
state = TearingState(sys)
981+
markio!(state, inputs, outputs)
982+
sys, input_idxs = _structural_simplify(sys, state; simplify, check_bound = false,
983+
kwargs...)
984+
985+
eqs = equations(sys)
986+
check_operator_variables(eqs, Differential)
987+
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
988+
diffstates = collect_operator_variables(sys, Differential)
989+
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
990+
alg = Base.Sort.DEFAULT_STABLE)
991+
@set! sys.eqs = eqs
992+
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
993+
sts = [diffstates; setdiff(states(sys), diffstates)]
994+
@set! sys.states = sts
995+
diff_idxs = 1:length(diffstates)
996+
alge_idxs = (length(diffstates) + 1):length(sts)
997+
998+
sys, diff_idxs, alge_idxs, input_idxs
968999
end
9691000

9701001
"""
@@ -994,36 +1025,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
9941025
function linearization_function(sys::AbstractSystem, inputs,
9951026
outputs; simplify = false,
9961027
kwargs...)
997-
sys = expand_connections(sys)
998-
state = TearingState(sys)
999-
markio!(state, inputs, outputs)
1000-
state, input_idxs = inputs_to_parameters!(state, false)
1001-
sys = alias_elimination!(state)
1002-
state = TearingState(sys)
1003-
check_consistency(state)
1004-
if sys isa ODESystem
1005-
sys = dae_order_lowering(dummy_derivative(sys, state))
1006-
end
1007-
state = TearingState(sys)
1008-
find_solvables!(state; kwargs...)
1009-
sys = tearing_reassemble(state, tearing(state), simplify = simplify)
1010-
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
1011-
@set! sys.observed = topsort_equations(observed(sys), fullstates)
1012-
invalidate_cache!(sys)
1013-
1014-
eqs = equations(sys)
1015-
check_operator_variables(eqs, Differential)
1016-
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
1017-
diffstates = collect_operator_variables(sys, Differential)
1018-
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
1019-
alg = Base.Sort.DEFAULT_STABLE)
1020-
@set! sys.eqs = eqs
1021-
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
1022-
sts = [diffstates; setdiff(states(sys), diffstates)]
1023-
@set! sys.states = sts
1024-
1025-
diff_idxs = 1:length(diffstates)
1026-
alge_idxs = (length(diffstates) + 1):length(sts)
1028+
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
1029+
kwargs...)
1030+
sts = states(sys)
10271031
fun = ODEFunction(sys)
10281032
lin_fun = let fun = fun,
10291033
h = ModelingToolkit.build_explicit_observed_function(sys, outputs)

test/input_output_handling.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ eqs = [
108108
]
109109

110110
@named sys = ODESystem(eqs)
111-
f, dvs, ps = ModelingToolkit.generate_control_function(sys, expression = Val{false},
112-
simplify = true)
111+
f, dvs, ps = ModelingToolkit.generate_control_function(sys, simplify = true)
113112

114113
@test isequal(dvs[], x)
115114
@test isempty(ps)
@@ -170,8 +169,7 @@ eqs = [connect_sd(sd, mass1, mass2)
170169
@named _model = ODESystem(eqs, t)
171170
@named model = compose(_model, mass1, mass2, sd);
172171

173-
f, dvs, ps = ModelingToolkit.generate_control_function(model, expression = Val{false},
174-
simplify = true)
172+
f, dvs, ps = ModelingToolkit.generate_control_function(model, simplify = true)
175173
@test length(dvs) == 4
176174
@test length(ps) == length(parameters(model))
177175
p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)

0 commit comments

Comments
 (0)