Skip to content

Commit fe0f93a

Browse files
committed
fix: fix linearization tests
1 parent 0f6d9ba commit fe0f93a

File tree

7 files changed

+46
-42
lines changed

7 files changed

+46
-42
lines changed

src/inputoutput.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
163163
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166-
disturbance_inputs = nothing;
166+
disturbance_inputs = Any[];
167167
implicit_dae = false,
168168
simplify = false,
169169
)
@@ -289,7 +289,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
289289
push!(new_fullvars, v)
290290
end
291291
end
292-
ninputs == 0 && return (state, 1:0)
292+
ninputs == 0 && return state
293293

294294
nvars = ndsts(graph) - ninputs
295295
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -318,14 +318,13 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
318318
@set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
319319
ps = parameters(sys)
320320

321-
if io !== nothing
322-
inputs, = io
321+
if inputsyms !== nothing
323322
# Change order of new parameters to correspond to user-provided order in argument `inputs`
324323
d = Dict{Any, Int}()
325324
for (i, inp) in enumerate(new_parameters)
326325
d[inp] = i
327326
end
328-
permutation = [d[i] for i in inputs]
327+
permutation = [d[i] for i in inputsyms]
329328
new_parameters = new_parameters[permutation]
330329
end
331330

@@ -334,8 +333,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
334333
@set! state.sys = sys
335334
@set! state.fullvars = new_fullvars
336335
@set! state.structure = structure
337-
base_params = length(ps)
338-
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
336+
return state
339337
end
340338

341339
"""
@@ -361,7 +359,7 @@ function get_disturbance_system(dist::DisturbanceModel{<:ODESystem})
361359
end
362360

363361
"""
364-
(f_oop, f_ip), augmented_sys, dvs, p = add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
362+
(f_oop, f_ip), augmented_sys, dvs, p = add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[])
365363
366364
Add a model of an unmeasured disturbance to `sys`. The disturbance model is an instance of [`DisturbanceModel`](@ref).
367365
@@ -410,7 +408,7 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i
410408
411409
`f_oop` will have an extra state corresponding to the integrator in the disturbance model. This state will not be affected by any input, but will affect the dynamics from where it enters, in this case it will affect additively from `model.torque.tau.u`.
412410
"""
413-
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kwargs...)
411+
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwargs...)
414412
t = get_iv(sys)
415413
@variables d(t)=0 [disturbance = true]
416414
@variables u(t)=0 [input = true] # New system input

src/linearization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function linearization_function(sys::AbstractSystem, inputs = unbound_inputs(sys
120120
end
121121

122122
lin_fun = LinearizationFunction(
123-
diff_idxs, alge_idxs, length(unknowns(sys)),
123+
diff_idxs, alge_idxs, inputs, length(unknowns(sys)),
124124
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
125125
hp_jac, initializealg, initialization_kwargs)
126126
return lin_fun
@@ -397,7 +397,7 @@ Construct a `LinearizationProblem` for linearizing the system `sys` with the giv
397397
All other keyword arguments are forwarded to `linearization_function`.
398398
"""
399399
function LinearizationProblem(sys::AbstractSystem, inputs, outputs; t = 0.0, kwargs...)
400-
linfun, _ = linearization_function(sys, inputs, outputs; kwargs...)
400+
linfun = linearization_function(sys, inputs, outputs; kwargs...)
401401
return LinearizationProblem(linfun, t)
402402
end
403403

@@ -764,7 +764,7 @@ Permute the state representation of `sys` obtained from [`linearize`](@ref) so t
764764
Example:
765765
766766
```
767-
lsys, ssys = linearize(pid, [reference.u, measurement.u], [ctr_output.u])
767+
lsys = linearize(pid, [reference.u, measurement.u], [ctr_output.u])
768768
desired_order = [int.x, der.x] # Unknowns that are present in unknowns(ssys)
769769
lsys = ModelingToolkit.reorder_unknowns(lsys, unknowns(ssys), desired_order)
770770
```

src/systems/analysis_points.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,13 @@ Add an input without an additional output variable.
596596
PerturbOutput(ap::AnalysisPoint) = PerturbOutput(ap, false)
597597

598598
function apply_transformation(tf::PerturbOutput, sys::AbstractSystem)
599+
@show "ok"
600+
@show tf.ap
599601
modify_nested_subsystem(sys, tf.ap) do ap_sys
600602
# get analysis point
603+
@show tf.ap
601604
ap_idx = analysis_point_index(ap_sys, tf.ap)
605+
@show ap_idx
602606
ap_idx === nothing &&
603607
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
604608
# modified equations

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
472472
- `eval_expression = false`: If true and `expression = false`, evaluates the returned function in the module `eval_module`
473473
- `output_type = Array` the type of the array generated by a out-of-place vector-valued function
474474
- `param_only = false` if true, only allow the generated function to access system parameters
475-
- `inputs = nothing` additinoal symbolic variables that should be provided to the generated function
475+
- `inputs = Any[]` additional symbolic variables that should be provided to the generated function
476476
- `checkbounds = true` checks bounds if true when destructuring parameters
477477
- `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
478478
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
@@ -502,8 +502,8 @@ For example, a function `g(op, unknowns, p..., inputs, t)` will be the in-place
502502
an array of inputs `inputs` is given, and `param_only` is false for a time-dependent system.
503503
"""
504504
function build_explicit_observed_function(sys, ts;
505-
inputs = nothing,
506-
disturbance_inputs = nothing,
505+
inputs = Any[],
506+
disturbance_inputs = Any[],
507507
disturbance_argument = false,
508508
expression = false,
509509
eval_expression = false,
@@ -576,13 +576,13 @@ function build_explicit_observed_function(sys, ts;
576576
else
577577
(unknowns(sys),)
578578
end
579-
if inputs === nothing
579+
if isempty(inputs)
580580
inputs = ()
581581
else
582582
ps = setdiff(ps, inputs) # Inputs have been converted to parameters, remove those from the parameter list
583583
inputs = (inputs,)
584584
end
585-
if disturbance_inputs !== nothing
585+
if !isempty(disturbance_inputs)
586586
# Disturbance inputs may or may not be included as inputs, depending on disturbance_argument
587587
ps = setdiff(ps, disturbance_inputs)
588588
end

src/systems/systems.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ topological sort of the observed equations in `sys`.
2929
function mtkbuild(
3030
sys::AbstractSystem; additional_passes = [], simplify = false, split = true,
3131
allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true,
32-
inputs = nothing, outputs = nothing,
33-
disturbance_inputs = nothing,
32+
inputs = Any[], outputs = Any[],
33+
disturbance_inputs = Any[],
3434
kwargs...)
3535
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
3636
newsys′ = __structural_simplification(sys; simplify,
@@ -74,8 +74,8 @@ function __structural_simplification(sys::SDESystem, args...; kwargs...)
7474
end
7575

7676
function __structural_simplification(sys::AbstractSystem; simplify = false,
77-
inputs = nothing, outputs = nothing,
78-
disturbance_inputs = nothing,
77+
inputs = Any[], outputs = Any[],
78+
disturbance_inputs = Any[],
7979
kwargs...)
8080
sys = expand_connections(sys)
8181
state = TearingState(sys; sort_eqs)

src/systems/systemstructure.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,8 @@ end
659659

660660
function structural_simplification!(state::TearingState; simplify = false,
661661
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
662-
inputs = nothing, outputs = nothing,
663-
disturbance_inputs = nothing,
662+
inputs = Any[], outputs = Any[],
663+
disturbance_inputs = Any[],
664664
kwargs...)
665665

666666
if state.sys isa ODESystem
@@ -672,7 +672,7 @@ function structural_simplification!(state::TearingState; simplify = false,
672672
cont_inputs = [inputs; clocked_inputs[continuous_id]]
673673
sys = _structural_simplification!(tss[continuous_id]; simplify,
674674
check_consistency, fully_determined,
675-
cont_inputs, outputs, disturbance_inputs,
675+
inputs = cont_inputs, outputs, disturbance_inputs,
676676
kwargs...)
677677
if length(tss) > 1
678678
if continuous_id > 0
@@ -690,7 +690,7 @@ function structural_simplification!(state::TearingState; simplify = false,
690690
end
691691
disc_inputs = [inputs; clocked_inputs[i]]
692692
ss, = _structural_simplification!(state; simplify, check_consistency,
693-
disc_inputs, outputs, disturbance_inputs,
693+
inputs = disc_inputs, outputs, disturbance_inputs,
694694
fully_determined, kwargs...)
695695
append!(appended_parameters, inputs[i], unknowns(ss))
696696
discrete_subsystems[i] = ss
@@ -717,8 +717,8 @@ end
717717
function _structural_simplification!(state::TearingState; simplify = false,
718718
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
719719
dummy_derivative = true,
720-
inputs = nothing, outputs = nothing,
721-
disturbance_inputs = nothing,
720+
inputs = Any[], outputs = Any[],
721+
disturbance_inputs = Any[],
722722
kwargs...)
723723
if fully_determined isa Bool
724724
check_consistency &= fully_determined

test/downstream/linearize.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ lsys3 = linearize(sys, [r], [y]; autodiff = AutoFiniteDiff())
2525
@test lsys.C[] == lsys2.C[] == lsys3.C[] == 1
2626
@test lsys.D[] == lsys2.D[] == lsys3.D[] == 0
2727

28-
lsys, ssys = linearize(sys, [r], [r])
28+
lsys = linearize(sys, [r], [r])
2929

3030
@test lsys.A[] == -2
3131
@test lsys.B[] == 1
3232
@test lsys.C[] == 0
3333
@test lsys.D[] == 1
3434

35-
lsys, ssys = linearize(sys, r, r) # Test allow scalars
35+
lsys = linearize(sys, r, r) # Test allow scalars
3636

3737
@test lsys.A[] == -2
3838
@test lsys.B[] == 1
@@ -89,19 +89,19 @@ connections = [f.y ~ c.r # filtered reference to controller reference
8989
@named cl = ODESystem(connections, t, systems = [f, c, p])
9090
cl = mtkbuild(cl, inputs = [f.u], outputs = [p.x])
9191

92-
lsys0, ssys = linearize(cl)
92+
lsys0 = linearize(cl, [f.u], [p.x])
9393
desired_order = [f.x, p.x]
94-
lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
95-
lsys1, ssys = linearize(cl; autodiff = AutoFiniteDiff())
96-
lsys2 = ModelingToolkit.reorder_unknowns(lsys1, unknowns(ssys), desired_order)
94+
lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(cl), desired_order)
95+
lsys1 = linearize(cl, [f.u], [p.x]; autodiff = AutoFiniteDiff())
96+
lsys2 = ModelingToolkit.reorder_unknowns(lsys1, unknowns(cl), desired_order)
9797

9898
@test lsys.A == lsys2.A == [-2 0; 1 -2]
9999
@test lsys.B == lsys2.B == reshape([1, 0], 2, 1)
100100
@test lsys.C == lsys2.C == [0 1]
101101
@test lsys.D[] == lsys2.D[] == 0
102102

103103
## Symbolic linearization
104-
lsyss, _ = ModelingToolkit.linearize_symbolic(cl, [f.u], [p.x])
104+
lsyss = ModelingToolkit.linearize_symbolic(cl, [f.u], [p.x])
105105

106106
@test ModelingToolkit.fixpoint_sub(lsyss.A, ModelingToolkit.defaults(cl)) == lsys.A
107107
@test ModelingToolkit.fixpoint_sub(lsyss.B, ModelingToolkit.defaults(cl)) == lsys.B
@@ -116,11 +116,12 @@ Nd = 10
116116
@named pid = LimPID(; k, Ti, Td, Nd)
117117

118118
@unpack reference, measurement, ctr_output = pid
119-
lsys0, ssys = linearize(pid, [reference.u, measurement.u], [ctr_output.u];
119+
pid = mtkbuild(pid, inputs = [reference.u, measurement.u], outputs = [ctr_output.u])
120+
lsys0 = linearize(pid, [reference.u, measurement.u], [ctr_output.u];
120121
op = Dict(reference.u => 0.0, measurement.u => 0.0))
121122
@unpack int, der = pid
122123
desired_order = [int.x, der.x]
123-
lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
124+
lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(pid), desired_order)
124125

125126
@test lsys.A == [0 0; 0 -10]
126127
@test lsys.B == [2 -2; 10 -10]
@@ -150,12 +151,12 @@ lsys = ModelingToolkit.reorder_unknowns(lsys, desired_order, reverse(desired_ord
150151

151152
## Test that there is a warning when input is misspecified
152153
if VERSION >= v"1.8"
153-
@test_throws "Some specified inputs were not found" linearize(pid,
154+
@test_throws "Some parameters are missing from the variable map." linearize(pid,
154155
[
155156
pid.reference.u,
156157
pid.measurement.u
157158
], [ctr_output.u])
158-
@test_throws "Some specified outputs were not found" linearize(pid,
159+
@test_throws "Some parameters are missing from the variable map." linearize(pid,
159160
[
160161
reference.u,
161162
measurement.u
@@ -186,15 +187,16 @@ function saturation(; y_max, y_min = y_max > 0 ? -y_max : -Inf, name)
186187
ODESystem(eqs, t, name = name)
187188
end
188189
@named sat = saturation(; y_max = 1)
190+
sat = mtkbuild(sat, inputs = [u], outputs = [y])
189191
# inside the linear region, the function is identity
190192
@unpack u, y = sat
191-
lsys, ssys = linearize(sat, [u], [y])
193+
lsys = linearize(sat, [u], [y])
192194
@test isempty(lsys.A) # there are no differential variables in this system
193195
@test isempty(lsys.B)
194196
@test isempty(lsys.C)
195197
@test lsys.D[] == 1
196198

197-
@test_skip lsyss, _ = ModelingToolkit.linearize_symbolic(sat, [u], [y]) # Code gen replaces ifelse with if statements causing symbolic evaluation to fail
199+
@test_skip lsyss = ModelingToolkit.linearize_symbolic(sat, [u], [y]) # Code gen replaces ifelse with if statements causing symbolic evaluation to fail
198200
# @test substitute(lsyss.A, ModelingToolkit.defaults(sat)) == lsys.A
199201
# @test substitute(lsyss.B, ModelingToolkit.defaults(sat)) == lsys.B
200202
# @test substitute(lsyss.C, ModelingToolkit.defaults(sat)) == lsys.C
@@ -267,9 +269,9 @@ closed_loop = ODESystem(connections, t, systems = [model, pid, filt, sensor, r,
267269
filt.x => 0.0,
268270
filt.xd => 0.0
269271
])
270-
closed_loop = mtkbuild(closed_loop, inputs = :r, outputs = :y)
272+
closed_loop = mtkbuild(closed_loop)
271273

272-
@test_nowarn linearize(closed_loop; warn_empty_op = false)
274+
@test_nowarn linearize(closed_loop, :r, :y; warn_empty_op = false)
273275

274276
# https://discourse.julialang.org/t/mtk-change-in-linearize/115760/3
275277
@mtkmodel Tank_noi begin

0 commit comments

Comments
 (0)