Skip to content

Commit 31f7a54

Browse files
Merge pull request #3273 from SciML/disturbance_args
add option to include disturbance args in `generate_control_function`
2 parents ba842c2 + a79f4c2 commit 31f7a54

File tree

3 files changed

+111
-30
lines changed

3 files changed

+111
-30
lines changed

src/inputoutput.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
160160
# Build control function
161161

162162
"""
163-
(f_oop, f_ip), x_sym, p, io_sys = generate_control_function(
163+
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166166
disturbance_inputs = nothing;
@@ -177,8 +177,7 @@ f_ip : (xout,x,u,p,t) -> nothing
177177
178178
The return values also include the chosen state-realization (the remaining unknowns) `x_sym` and parameters, in the order they appear as arguments to `f`.
179179
180-
If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement.
181-
See [`add_input_disturbance`](@ref) for a higher-level interface to this functionality.
180+
If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will (by default) not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. To add an input argument corresponding to the disturbance inputs, either include the disturbance inputs among the control inputs, or set `disturbance_argument=true`, in which case an additional input argument `w` is added to the generated function `(x,u,p,t,w)->rhs`.
182181
183182
!!! note "Un-simplified system"
184183
This function expects `sys` to be un-simplified, i.e., `structural_simplify` or `@mtkbuild` should not be called on the system before passing it into this function. `generate_control_function` calls a special version of `structural_simplify` internally.
@@ -196,6 +195,7 @@ f[1](x, inputs, p, t)
196195
"""
197196
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys),
198197
disturbance_inputs = disturbances(sys);
198+
disturbance_argument = false,
199199
implicit_dae = false,
200200
simplify = false,
201201
eval_expression = false,
@@ -219,10 +219,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
219219
# ps = [ps; disturbance_inputs]
220220
end
221221
inputs = map(x -> time_varying_as_func(value(x), sys), inputs)
222+
disturbance_inputs = unwrap.(disturbance_inputs)
222223

223224
eqs = [eq for eq in full_equations(sys)]
224225
eqs = map(subs_constants, eqs)
225-
if disturbance_inputs !== nothing
226+
if disturbance_inputs !== nothing && !disturbance_argument
226227
# Set all disturbance *inputs* to zero (we just want to keep the disturbance state)
227228
subs = Dict(disturbance_inputs .=> 0)
228229
eqs = [eq.lhs ~ substitute(eq.rhs, subs) for eq in eqs]
@@ -239,16 +240,24 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
239240
t = get_iv(sys)
240241

241242
# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
242-
243-
args = (u, inputs, p..., t)
243+
if disturbance_argument
244+
args = (u, inputs, p..., t, disturbance_inputs)
245+
else
246+
args = (u, inputs, p..., t)
247+
end
244248
if implicit_dae
245249
ddvs = map(Differential(get_iv(sys)), dvs)
246250
args = (ddvs, args...)
247251
end
248252
process = get_postprocess_fbody(sys)
253+
wrapped_arrays_vars = disturbance_argument ?
254+
wrap_array_vars(
255+
sys, rhss; dvs, ps, inputs, extra_args = (disturbance_inputs,)) :
256+
wrap_array_vars(sys, rhss; dvs, ps, inputs)
249257
f = build_function(rhss, args...; postprocess_fbody = process,
250-
expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘
251-
wrap_array_vars(sys, rhss; dvs, ps, inputs) .∘
258+
expression = Val{true}, wrap_code = wrap_mtkparameters(
259+
sys, false, 3, Int(disturbance_argument) + 1) .∘
260+
wrapped_arrays_vars .∘
252261
wrap_parameter_dependencies(sys, false),
253262
kwargs...)
254263
f = eval_or_rgf.(f; eval_expression, eval_module)

src/systems/abstractsystem.jl

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,33 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230230
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
231231
end
232232

233+
"""
234+
$(TYPEDSIGNATURES)
235+
236+
Add the necessary assignment statements to allow use of unscalarized array variables
237+
in the generated code. `expr` is the expression returned by the function. `dvs` and
238+
`ps` are the unknowns and parameters of the system `sys` to use in the generated code.
239+
`inputs` can be specified as an array of symbolics if the generated function has inputs.
240+
If `history == true`, the generated function accepts a history function. `cachesyms` are
241+
extra variables (arrays of variables) stored in the cache array(s) of the parameter
242+
object. `extra_args` are extra arguments appended to the end of the argument list.
243+
244+
The function is assumed to have the signature `f(du, u, h, x, p, cache_syms..., t, extra_args...)`
245+
Where:
246+
- `du` is the optional buffer to write to for in-place functions.
247+
- `u` is the list of unknowns. This argument is not present if `dvs === nothing`.
248+
- `h` is the optional history function, present if `history == true`.
249+
- `x` is the array of inputs, present only if `inputs !== nothing`. Values are assumed
250+
to be in the order of variables passed to `inputs`.
251+
- `p` is the parameter object.
252+
- `cache_syms` are the cache variables. These are part of the splatted parameter object.
253+
- `t` is time, present only if the system is time dependent.
254+
- `extra_args` are the extra arguments passed to the function, present only if
255+
`extra_args` is non-empty.
256+
"""
233257
function wrap_array_vars(
234258
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
235-
inputs = nothing, history = false, cachesyms::Tuple = ())
259+
inputs = nothing, history = false, cachesyms::Tuple = (), extra_args::Tuple = ())
236260
isscalar = !(exprs isa AbstractArray)
237261
var_to_arridxs = Dict()
238262

@@ -252,6 +276,10 @@ function wrap_array_vars(
252276
if inputs !== nothing
253277
rps = (inputs, rps...)
254278
end
279+
if has_iv(sys)
280+
rps = (rps..., get_iv(sys))
281+
end
282+
rps = (rps..., extra_args...)
255283
for sym in reduce(vcat, rps; init = [])
256284
iscall(sym) && operation(sym) == getindex || continue
257285
arg = arguments(sym)[1]
@@ -332,7 +360,7 @@ end
332360
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)
333361

334362
"""
335-
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
363+
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, offset = Int(is_time_dependent(sys)))
336364
337365
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
338366
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
@@ -342,12 +370,14 @@ the first parameter vector in the out-of-place version of the function. For exam
342370
if a history function (DDEs) was passed before `p`, then the function before wrapping
343371
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
344372
373+
`offset` is the number of arguments at the end of the argument list to ignore. Defaults
374+
to 1 if the system is time-dependent (to ignore `t`) and 0 otherwise.
375+
345376
The returned function is `identity` if the system does not have an `IndexCache`.
346377
"""
347-
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
378+
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2,
379+
offset = Int(is_time_dependent(sys)))
348380
if has_index_cache(sys) && get_index_cache(sys) !== nothing
349-
offset = Int(is_time_dependent(sys))
350-
351381
if isscalar
352382
function (expr)
353383
param_args = expr.args[p_start:(end - offset)]

test/input_output_handling.jl

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -153,24 +153,66 @@ if VERSION >= v"1.8" # :opaque_closure not supported before
153153
end
154154

155155
## Code generation with unbound inputs
156+
@testset "generate_control_function with disturbance inputs" begin
157+
for split in [true, false]
158+
simplify = true
159+
160+
@variables x(t)=0 u(t)=0 [input = true]
161+
eqs = [
162+
D(x) ~ -x + u
163+
]
164+
165+
@named sys = ODESystem(eqs, t)
166+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split)
167+
168+
@test isequal(dvs[], x)
169+
@test isempty(ps)
170+
171+
p = nothing
172+
x = [rand()]
173+
u = [rand()]
174+
@test f[1](x, u, p, 1) == -x + u
175+
176+
# With disturbance inputs
177+
@variables x(t)=0 u(t)=0 [input = true] d(t)=0
178+
eqs = [
179+
D(x) ~ -x + u + d^2
180+
]
181+
182+
@named sys = ODESystem(eqs, t)
183+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
184+
sys, [u], [d]; simplify, split)
185+
186+
@test isequal(dvs[], x)
187+
@test isempty(ps)
188+
189+
p = nothing
190+
x = [rand()]
191+
u = [rand()]
192+
@test f[1](x, u, p, 1) == -x + u
193+
194+
## With added d argument
195+
@variables x(t)=0 u(t)=0 [input = true] d(t)=0
196+
eqs = [
197+
D(x) ~ -x + u + d^2
198+
]
199+
200+
@named sys = ODESystem(eqs, t)
201+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
202+
sys, [u], [d]; simplify, split, disturbance_argument = true)
203+
204+
@test isequal(dvs[], x)
205+
@test isempty(ps)
206+
207+
p = nothing
208+
x = [rand()]
209+
u = [rand()]
210+
d = [rand()]
211+
@test f[1](x, u, p, t, d) == -x + u + [d[]^2]
212+
end
213+
end
156214

157-
@variables x(t)=0 u(t)=0 [input = true]
158-
eqs = [
159-
D(x) ~ -x + u
160-
]
161-
162-
@named sys = ODESystem(eqs, t)
163-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
164-
165-
@test isequal(dvs[], x)
166-
@test isempty(ps)
167-
168-
p = nothing
169-
x = [rand()]
170-
u = [rand()]
171-
@test f[1](x, u, p, 1) == -x + u
172-
173-
# more complicated system
215+
## more complicated system
174216

175217
@variables u(t) [input = true]
176218

0 commit comments

Comments
 (0)