Skip to content

Commit a261271

Browse files
committed
add generate_control_function
1 parent 39d8f09 commit a261271

File tree

2 files changed

+176
-1
lines changed

2 files changed

+176
-1
lines changed

src/inputoutput.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,91 @@ end
158158

159159
has_var(ex, x) = x Set(get_variables(ex))
160160

161+
# Build control function
162+
163+
"""
164+
(f_oop, f_ip), dvs, p = generate_control_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); implicit_dae = false, ddvs = if implicit_dae
165+
166+
For a system `sys` that has unbound inputs (as determined by [`unbound_inputs`](@ref)), generate a function with additional input argument `in`
167+
```
168+
f_oop : (u,in,p,t) -> rhs
169+
f_ip : (uout,u,in,p,t) -> nothing
170+
```
171+
The return values also include the remaining states and parameters, in the order they appear as arguments to `f`.
172+
173+
# Example
174+
```
175+
using ModelingToolkit: generate_control_function, varmap_to_vars, defaults
176+
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=true)
177+
p = varmap_to_vars(defaults(sys), ps)
178+
x = varmap_to_vars(defaults(sys), dvs)
179+
t = 0
180+
f[1](x, inputs, p, t)
181+
```
182+
"""
183+
function generate_control_function(
184+
sys::AbstractODESystem;
185+
implicit_dae=false,
186+
has_difference=false,
187+
simplify=true,
188+
kwargs...
189+
)
190+
191+
ctrls = unbound_inputs(sys)
192+
if isempty(ctrls)
193+
error("No unbound inputs were found in system.")
194+
end
195+
196+
# One can either connect unbound inputs to new parameters and allow structural_simplify, but then the unbound inputs appear as states :( .
197+
# One can also just remove them from the states and parameters for the purposes of code generation, but then structural_simplify fails :(
198+
# To have the best of both worlds, all unbound inputs must be converted to `@parameters` in which case structural_simplify handles them correctly :)
199+
sys = toparam(sys, ctrls)
200+
201+
if simplify
202+
sys = structural_simplify(sys)
203+
end
204+
205+
dvs = states(sys)
206+
ps = parameters(sys)
207+
208+
dvs = setdiff(dvs, ctrls)
209+
ps = setdiff(ps, ctrls)
210+
inputs = map(x->time_varying_as_func(value(x), sys), ctrls)
211+
212+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
213+
foreach(check_derivative_variables, eqs)
214+
# substitute x(t) by just x
215+
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
216+
[eq.rhs for eq in eqs]
217+
218+
219+
# TODO: add an optional check on the ordering of observed equations
220+
u = map(x->time_varying_as_func(value(x), sys), dvs)
221+
p = map(x->time_varying_as_func(value(x), sys), ps)
222+
t = get_iv(sys)
223+
224+
# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
225+
226+
args = (u, inputs, p, t)
227+
if implicit_dae
228+
ddvs = map(Differential(get_iv(sys)), dvs)
229+
args = (ddvs, args...)
230+
end
231+
pre, sol_states = get_substitutions_and_solved_states(sys)
232+
f = build_function(rhss, args...; postprocess_fbody=pre, states=sol_states, kwargs...)
233+
f, dvs, ps
234+
end
235+
236+
"""
237+
toparam(sys, ctrls::AbstractVector)
238+
239+
Transform all instances of `@varibales` in `ctrls` appearing as states and in equations of `sys` with similarly named `@parameters`. This allows [`structural_simplify`](@ref)(sys) in the presence unbound inputs.
240+
"""
241+
function toparam(sys, ctrls::AbstractVector)
242+
eqs = equations(sys)
243+
subs = Dict(ctrls .=> toparam.(ctrls))
244+
eqs = map(eqs) do eq
245+
substitute(eq.lhs, subs) ~ substitute(eq.rhs, subs)
246+
end
247+
ODESystem(eqs, name=sys.name)
248+
end

test/input_output_handling.jl

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,91 @@ syss = structural_simplify(sys2)
8282
@test is_bound(syss, sys.y)
8383

8484
@test isequal(unbound_outputs(syss), [y])
85-
@test isequal(bound_outputs(syss), [sys.y])
85+
@test isequal(bound_outputs(syss), [sys.y])
86+
87+
88+
## Code generation with unbound inputs
89+
90+
@variables t x(t)=0 u(t)=0 [input=true]
91+
D = Differential(t)
92+
eqs = [
93+
D(x) ~ -x + u
94+
]
95+
96+
@named sys = ODESystem(eqs)
97+
f, dvs, ps = ModelingToolkit.generate_control_function(sys, expression=Val{false}, simplify=true)
98+
99+
@test isequal(dvs[], x)
100+
@test isempty(ps)
101+
102+
p = []
103+
x = [rand()]
104+
u = [rand()]
105+
@test f[1](x,u,p,1) == -x + u
106+
107+
108+
# more complicated system
109+
110+
@variables u(t) [input=true]
111+
112+
function Mass(; name, m = 1.0, p = 0, v = 0)
113+
@variables y(t) [output=true]
114+
ps = @parameters m=m
115+
sts = @variables pos(t)=p vel(t)=v
116+
eqs = [
117+
D(pos) ~ vel
118+
y ~ pos
119+
]
120+
ODESystem(eqs, t, [pos, vel], ps; name)
121+
end
122+
123+
function Spring(; name, k = 1e4)
124+
ps = @parameters k=k
125+
@variables x(t)=0 # Spring deflection
126+
ODESystem(Equation[], t, [x], ps; name)
127+
end
128+
129+
function Damper(; name, c = 10)
130+
ps = @parameters c=c
131+
@variables vel(t)=0
132+
ODESystem(Equation[], t, [vel], ps; name)
133+
end
134+
135+
function SpringDamper(; name, k=false, c=false)
136+
spring = Spring(; name=:spring, k)
137+
damper = Damper(; name=:damper, c)
138+
compose(
139+
ODESystem(Equation[], t; name),
140+
spring, damper)
141+
end
142+
143+
144+
connect_sd(sd, m1, m2) = [sd.spring.x ~ m1.pos - m2.pos, sd.damper.vel ~ m1.vel - m2.vel]
145+
sd_force(sd) = -sd.spring.k * sd.spring.x - sd.damper.c * sd.damper.vel
146+
147+
# Parameters
148+
m1 = 1
149+
m2 = 1
150+
k = 1000
151+
c = 10
152+
153+
@named mass1 = Mass(; m=m1)
154+
@named mass2 = Mass(; m=m2)
155+
@named sd = SpringDamper(; k, c)
156+
157+
eqs = [
158+
connect_sd(sd, mass1, mass2)
159+
D(mass1.vel) ~ ( sd_force(sd) + u) / mass1.m
160+
D(mass2.vel) ~ (-sd_force(sd)) / mass2.m
161+
]
162+
@named _model = ODESystem(eqs, t)
163+
@named model = compose(_model, mass1, mass2, sd);
164+
165+
166+
f, dvs, ps = ModelingToolkit.generate_control_function(model, expression=Val{false}, simplify=true)
167+
@test length(dvs) == 4
168+
@test length(ps) == length(parameters(model))
169+
p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)
170+
x = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), dvs)
171+
u = [rand()]
172+
@test f[1](x,u,p,1) == [u;0;0;0]

0 commit comments

Comments
 (0)