Skip to content

Commit 9626fee

Browse files
committed
improve robustness in generate_control_function
by reusing `io_preprocessing`
1 parent 4c31ede commit 9626fee

File tree

2 files changed

+11
-37
lines changed

2 files changed

+11
-37
lines changed

src/inputoutput.jl

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -170,38 +170,28 @@ The return values also include the remaining states and parameters, in the order
170170
# Example
171171
```
172172
using ModelingToolkit: generate_control_function, varmap_to_vars, defaults
173-
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=true)
173+
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=false)
174174
p = varmap_to_vars(defaults(sys), ps)
175175
x = varmap_to_vars(defaults(sys), dvs)
176176
t = 0
177177
f[1](x, inputs, p, t)
178178
```
179179
"""
180-
function generate_control_function(sys::AbstractODESystem;
180+
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys);
181181
implicit_dae = false,
182-
has_difference = false,
183-
simplify = true,
182+
simplify = false,
184183
kwargs...)
185-
ctrls = unbound_inputs(sys)
186-
if isempty(ctrls)
184+
if isempty(inputs)
187185
error("No unbound inputs were found in system.")
188186
end
189187

190-
# One can either connect unbound inputs to new parameters and allow structural_simplify, but then the unbound inputs appear as states :( .
191-
# One can also just remove them from the states and parameters for the purposes of code generation, but then structural_simplify fails :(
192-
# To have the best of both worlds, all unbound inputs must be converted to `@parameters` in which case structural_simplify handles them correctly :)
193-
sys = toparam(sys, ctrls)
194-
195-
if simplify
196-
sys = structural_simplify(sys)
197-
end
188+
sys, diff_idxs, alge_idxs = io_preprocessing(sys, inputs, []; simplify,
189+
check_bound = false, kwargs...)
198190

199191
dvs = states(sys)
200192
ps = parameters(sys)
201-
202-
dvs = setdiff(dvs, ctrls)
203-
ps = setdiff(ps, ctrls)
204-
inputs = map(x -> time_varying_as_func(value(x), sys), ctrls)
193+
ps = setdiff(ps, inputs)
194+
inputs = map(x -> time_varying_as_func(value(x), sys), inputs)
205195

206196
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
207197
check_operator_variables(eqs, Differential)
@@ -223,24 +213,10 @@ function generate_control_function(sys::AbstractODESystem;
223213
end
224214
pre, sol_states = get_substitutions_and_solved_states(sys)
225215
f = build_function(rhss, args...; postprocess_fbody = pre, states = sol_states,
226-
kwargs...)
216+
expression = Val{false}, kwargs...)
227217
f, dvs, ps
228218
end
229219

230-
"""
231-
toparam(sys, ctrls::AbstractVector)
232-
233-
Transform all instances of `@varibales` in `ctrls` appearing as states and in equations of `sys` with similarly named `@parameters`. This allows [`structural_simplify`](@ref)(sys) in the presence unbound inputs.
234-
"""
235-
function toparam(sys, ctrls::AbstractVector)
236-
eqs = equations(sys)
237-
subs = Dict(ctrls .=> toparam.(ctrls))
238-
eqs = map(eqs) do eq
239-
substitute(eq.lhs, subs) ~ substitute(eq.rhs, subs)
240-
end
241-
ODESystem(eqs, name = nameof(sys))
242-
end
243-
244220
function inputs_to_parameters!(state::TransformationState, check_bound = true)
245221
@unpack structure, fullvars, sys = state
246222
@unpack var_to_diff, graph, solvable_graph = structure

test/input_output_handling.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ eqs = [
108108
]
109109

110110
@named sys = ODESystem(eqs)
111-
f, dvs, ps = ModelingToolkit.generate_control_function(sys, expression = Val{false},
112-
simplify = true)
111+
f, dvs, ps = ModelingToolkit.generate_control_function(sys, simplify = true)
113112

114113
@test isequal(dvs[], x)
115114
@test isempty(ps)
@@ -170,8 +169,7 @@ eqs = [connect_sd(sd, mass1, mass2)
170169
@named _model = ODESystem(eqs, t)
171170
@named model = compose(_model, mass1, mass2, sd);
172171

173-
f, dvs, ps = ModelingToolkit.generate_control_function(model, expression = Val{false},
174-
simplify = true)
172+
f, dvs, ps = ModelingToolkit.generate_control_function(model, simplify = true)
175173
@test length(dvs) == 4
176174
@test length(ps) == length(parameters(model))
177175
p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)

0 commit comments

Comments
 (0)