Skip to content

Commit bf3551e

Browse files
authored
add input signals to build_explicit_observed_function (#2199)
1 parent 9869b2b commit bf3551e

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,11 @@ Build the observed function assuming the observed equations are all explicit,
304304
i.e. there are no cycles.
305305
"""
306306
function build_explicit_observed_function(sys, ts;
307+
inputs = nothing,
307308
expression = false,
308309
output_type = Array,
309310
checkbounds = true,
311+
drop_expr = drop_expr,
310312
throw = true)
311313
if (isscalar = !(ts isa AbstractVector))
312314
ts = [ts]
@@ -378,9 +380,18 @@ function build_explicit_observed_function(sys, ts;
378380
push!(obsexprs, lhs rhs)
379381
end
380382

383+
pars = parameters(sys)
384+
if inputs !== nothing
385+
pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
386+
end
387+
ps = DestructuredArgs(pars, inbounds = !checkbounds)
381388
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
382-
ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds)
383-
args = [dvs, ps, ivs...]
389+
if inputs === nothing
390+
args = [dvs, ps, ivs...]
391+
else
392+
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
393+
args = [dvs, ipts, ps, ivs...]
394+
end
384395
pre = get_postprocess_fbody(sys)
385396

386397
ex = Func(args, [],

test/input_output_handling.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,33 @@ t = ModelingToolkitStandardLibrary.Mechanical.Rotational.t
127127
@named spring = Spring(; c = 10)
128128
@named damper = Damper(; d = 3)
129129
@named torque = Torque()
130+
@variables y(t) = 0
130131
eqs = [connect(torque.flange, inertia1.flange_a)
131132
connect(inertia1.flange_b, spring.flange_a, damper.flange_a)
132-
connect(inertia2.flange_a, spring.flange_b, damper.flange_b)]
133+
connect(inertia2.flange_a, spring.flange_b, damper.flange_b)
134+
y ~ inertia2.w + torque.tau.u]
133135
model = ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper],
134136
name = :name)
135137
model_outputs = [inertia1.w, inertia2.w, inertia1.phi, inertia2.phi]
136138
model_inputs = [torque.tau.u]
137139
matrices, ssys = linearize(model, model_inputs, model_outputs)
138140
@test length(ModelingToolkit.outputs(ssys)) == 4
139141

142+
if VERSION >= v"1.8" # :opaque_closure not supported before
143+
matrices, ssys = linearize(model, model_inputs, [y])
144+
A, B, C, D = matrices
145+
obsf = ModelingToolkit.build_explicit_observed_function(ssys,
146+
[y],
147+
inputs = [torque.tau.u],
148+
drop_expr = identity)
149+
x = randn(size(A, 1))
150+
u = randn(size(B, 2))
151+
p = getindex.(Ref(ModelingToolkit.defaults(ssys)), parameters(ssys))
152+
y1 = obsf(x, u, p, 0)
153+
y2 = C * x + D * u
154+
@test y1[] y2[]
155+
end
156+
140157
## Code generation with unbound inputs
141158

142159
@variables t x(t)=0 u(t)=0 [input = true]

0 commit comments

Comments
 (0)