Skip to content

Commit e8a9f04

Browse files
Merge pull request #3220 from SciML/infiniteopt
add capability to trace MTK dynamics with InfiniteOpt
2 parents 8e48f65 + 7200999 commit e8a9f04

File tree

5 files changed

+135
-0
lines changed

5 files changed

+135
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6464
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6565
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
6666
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
67+
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
6768
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
6869

6970
[extensions]
@@ -72,6 +73,7 @@ MTKChainRulesCoreExt = "ChainRulesCore"
7273
MTKDeepDiffsExt = "DeepDiffs"
7374
MTKHomotopyContinuationExt = "HomotopyContinuation"
7475
MTKLabelledArraysExt = "LabelledArrays"
76+
MTKInfiniteOptExt = "InfiniteOpt"
7577

7678
[compat]
7779
AbstractTrees = "0.3, 0.4"
@@ -104,6 +106,7 @@ FunctionWrappers = "1.1"
104106
FunctionWrappersWrappers = "0.1"
105107
Graphs = "1.5.2"
106108
HomotopyContinuation = "2.11"
109+
InfiniteOpt = "0.5"
107110
InteractiveUtils = "1"
108111
JuliaFormatter = "1.0.47"
109112
JumpProcesses = "9.13.1"

ext/MTKInfiniteOptExt.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module MTKInfiniteOptExt
2+
import ModelingToolkit
3+
import SymbolicUtils
4+
import NaNMath
5+
import InfiniteOpt
6+
import InfiniteOpt: JuMP, GeneralVariableRef
7+
8+
# This file contains method definitions to make it possible to trace through functions generated by MTK using JuMP variables
9+
10+
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
11+
f = nameof(ff)
12+
# These need to be defined so that JuMP can trace through functions built by Symbolics
13+
@eval NaNMath.$f(x::GeneralVariableRef) = Base.$f(x)
14+
end
15+
16+
# JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both.
17+
function Base.isequal(::SymbolicUtils.Symbolic,
18+
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr})
19+
false
20+
end
21+
function Base.isequal(
22+
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr},
23+
::SymbolicUtils.Symbolic)
24+
false
25+
end
26+
end

test/extensions/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
55
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
7+
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
8+
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
9+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
710
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
811
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
912
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"

test/extensions/test_infiniteopt.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
using ModelingToolkit, InfiniteOpt, JuMP, Ipopt
2+
using ModelingToolkit: D_nounits as D, t_nounits as t, varmap_to_vars
3+
4+
@mtkmodel Pendulum begin
5+
@parameters begin
6+
g = 9.8
7+
L = 0.4
8+
K = 1.2
9+
m = 0.3
10+
end
11+
@variables begin
12+
θ(t) # state
13+
ω(t) # state
14+
τ(t) = 0 # input
15+
y(t) # output
16+
end
17+
@equations begin
18+
D(θ) ~ ω
19+
D(ω) ~ -g / L * sin(θ) - K / m * ω + τ / m / L^2
20+
y ~ θ * 180 / π
21+
end
22+
end
23+
@named model = Pendulum()
24+
model = complete(model)
25+
26+
inputs = [model.τ]
27+
(f_oop, f_ip), dvs, psym, io_sys = ModelingToolkit.generate_control_function(
28+
model, inputs, split = false)
29+
30+
outputs = [model.y]
31+
f_obs = ModelingToolkit.build_explicit_observed_function(io_sys, outputs; inputs = inputs)
32+
33+
expected_state_order = [model.θ, model.ω]
34+
permutation = [findfirst(isequal(x), expected_state_order) for x in dvs] # This maps our expected state order to the actual state order
35+
36+
##
37+
38+
ub = varmap_to_vars([model.θ => 2pi, model.ω => 10], dvs)
39+
lb = varmap_to_vars([model.θ => -2pi, model.ω => -10], dvs)
40+
xf = varmap_to_vars([model.θ => pi, model.ω => 0], dvs)
41+
nx = length(dvs)
42+
nu = length(inputs)
43+
ny = length(outputs)
44+
45+
##
46+
m = InfiniteModel(optimizer_with_attributes(Ipopt.Optimizer,
47+
"print_level" => 0, "acceptable_tol" => 1e-3, "constr_viol_tol" => 1e-5, "max_iter" => 1000,
48+
"tol" => 1e-5, "mu_strategy" => "monotone", "nlp_scaling_method" => "gradient-based",
49+
"alpha_for_y" => "safer-min-dual-infeas", "bound_mult_init_method" => "mu-based", "print_user_options" => "yes"));
50+
51+
@infinite_parameter(m, τ in [0, 1], num_supports=51,
52+
derivative_method=OrthogonalCollocation(4)) # Time variable
53+
guess_xs = [t -> pi, t -> 0.1][permutation]
54+
guess_us = [t -> 0.1]
55+
InfiniteOpt.@variables(m,
56+
begin
57+
# state variables
58+
(lb[i] <= x[i = 1:nx] <= ub[i], Infinite(τ), start = guess_xs[i]) # state variables
59+
-10 <= u[i = 1:nu] <= 10, Infinite(τ), (start = guess_us[i]) # control variables
60+
0 <= tf <= 10, (start = 5) # Final time
61+
0.2 <= L <= 0.6, (start = 0.4) # Length parameter
62+
end)
63+
64+
# Trace the dynamics
65+
x0, p = ModelingToolkit.get_u0_p(io_sys, [model.θ => 0, model.ω => 0], [model.L => L])
66+
67+
xp = f_oop(x, u, p, τ)
68+
cp = f_obs(x, u, p, τ) # Test that it's possible to trace through an observed function
69+
70+
@objective(m, Min, tf)
71+
@constraint(m, [i = 1:nx], x[i](0)==x0[i]) # Initial condition
72+
@constraint(m, [i = 1:nx], x[i](1)==xf[i]) # Terminal state
73+
74+
x_scale = varmap_to_vars([model.θ => 1
75+
model.ω => 1], dvs)
76+
77+
# Add dynamics constraints
78+
@constraint(m, [i = 1:nx], ((x[i], τ) - tf * xp[i]) / x_scale[i]==0)
79+
80+
optimize!(m)
81+
82+
# Extract the optimal solution
83+
opt_tf = value(tf)
84+
opt_time = opt_tf * value(τ)
85+
opt_x = [value(x[i]) for i in permutation]
86+
opt_u = [value(u[i]) for i in 1:nu]
87+
opt_L = value(L)
88+
89+
# Plot the results
90+
# using Plots
91+
# plot(opt_time, opt_x[1], label = "θ", xlabel = "Time [s]", layout=3)
92+
# plot!(opt_time, opt_x[2], label = "ω", sp=2)
93+
# plot!(opt_time, opt_u[1], label = "τ", sp=3)
94+
95+
using Test
96+
@test opt_x[1][end]pi atol=1e-3
97+
@test opt_x[2][end]0 atol=1e-3
98+
99+
@test opt_x[1][1]0 atol=1e-3
100+
@test opt_x[2][1]0 atol=1e-3
101+
102+
@test opt_L0.2 atol=1e-3 # Smallest permissible length is optimal

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,6 @@ end
116116
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
117117
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
118118
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
119+
@safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl")
119120
end
120121
end

0 commit comments

Comments
 (0)