@@ -10,14 +10,14 @@ $(FIELDS)
10
10
11
11
```julia
12
12
@variables x y z
13
- @parameters σ ρ β
13
+ @parameters a b c
14
14
15
- op = σ *(y-x) + x*(ρ -z)-y + x*y - β *z
16
- @named os = OptimizationSystem(op, [x,y,z],[σ,ρ,β ])
15
+ op = a *(y-x) + x*(b -z)-y + x*y - c *z
16
+ @named os = OptimizationSystem(op, [x,y,z], [a,b,c ])
17
17
```
18
18
"""
19
19
struct OptimizationSystem <: AbstractTimeIndependentSystem
20
- """ Vector of equations defining the system."""
20
+ """ Objective function of the system."""
21
21
op:: Any
22
22
""" Unknown variables."""
23
23
states:: Vector
@@ -26,18 +26,15 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
26
26
""" Array variables."""
27
27
var_to_name:: Any
28
28
observed:: Vector{Equation}
29
- constraints:: Vector
30
- """
31
- Name: the name of the system. These are required to have unique names.
32
- """
29
+ """ List of constraint equations of the system."""
30
+ constraints:: Vector # {Union{Equation,Inequality}}
31
+ """ The unique name of the system."""
33
32
name:: Symbol
34
- """
35
- systems: The internal systems
36
- """
33
+ """ The internal systems."""
37
34
systems:: Vector{OptimizationSystem}
38
35
"""
39
- defaults: The default values to use when initial conditions and/or
40
- parameters are not supplied in `ODEProblem `.
36
+ The default values to use when initial guess and/or
37
+ parameters are not supplied in `OptimizationProblem `.
41
38
"""
42
39
defaults:: Dict
43
40
"""
@@ -48,7 +45,7 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
48
45
constraints, name, systems, defaults, metadata = nothing ;
49
46
checks:: Union{Bool, Int} = true )
50
47
if checks == true || (checks & CheckUnits) > 0
51
- check_units (op)
48
+ unwrap (op) isa Symbolic && check_units (op)
52
49
check_units (observed)
53
50
all_dimensionless ([states; ps]) || check_units (constraints)
54
51
end
@@ -69,6 +66,11 @@ function OptimizationSystem(op, states, ps;
69
66
metadata = nothing )
70
67
name === nothing &&
71
68
throw (ArgumentError (" The `name` keyword must be provided. Please consider using the `@named` macro" ))
69
+
70
+ constraints = value .(scalarize (constraints))
71
+ states′ = value .(scalarize (states))
72
+ ps′ = value .(scalarize (ps))
73
+
72
74
if ! (isempty (default_u0) && isempty (default_p))
73
75
Base. depwarn (" `default_u0` and `default_p` are deprecated. Use `defaults` instead." ,
74
76
:OptimizationSystem , force = true )
@@ -80,12 +82,12 @@ function OptimizationSystem(op, states, ps;
80
82
defaults = todict (defaults)
81
83
defaults = Dict (value (k) => value (v) for (k, v) in pairs (defaults))
82
84
83
- states, ps = value .(states), value .(ps)
84
85
var_to_name = Dict ()
85
- process_variables! (var_to_name, defaults, states)
86
- process_variables! (var_to_name, defaults, ps)
86
+ process_variables! (var_to_name, defaults, states′ )
87
+ process_variables! (var_to_name, defaults, ps′ )
87
88
isempty (observed) || collect_var_to_name! (var_to_name, (eq. lhs for eq in observed))
88
- OptimizationSystem (value (op), states, ps, var_to_name,
89
+
90
+ OptimizationSystem (value (op), states′, ps′, var_to_name,
89
91
observed,
90
92
constraints,
91
93
name, systems, defaults, metadata; checks = checks)
@@ -124,10 +126,38 @@ function generate_function(sys::OptimizationSystem, vs = states(sys), ps = param
124
126
end
125
127
126
128
function equations (sys:: OptimizationSystem )
127
- isempty (get_systems (sys)) ? get_op (sys) :
128
- get_op (sys) + reduce (+ , namespace_expr .(get_systems (sys)))
129
+ op = get_op (sys)
130
+ systems = get_systems (sys)
131
+ if isempty (systems)
132
+ op
133
+ else
134
+ op + reduce (+ , map (sys_ -> namespace_expr (get_op (sys_), sys_), systems))
135
+ end
136
+ end
137
+
138
+ namespace_constraint (eq:: Equation , sys) = namespace_equation (eq, sys)
139
+
140
+ # namespace_constraint(ineq::Inequality, sys) = namespace_inequality(ineq, sys)
141
+
142
+ # function namespace_inequality(ineq::Inequality, sys, n = nameof(sys))
143
+ # _lhs = namespace_expr(ineq.lhs, sys, n)
144
+ # _rhs = namespace_expr(ineq.rhs, sys, n)
145
+ # Inequality(
146
+ # namespace_expr(_lhs, sys, n),
147
+ # namespace_expr(_rhs, sys, n),
148
+ # ineq.relational_op,
149
+ # )
150
+ # end
151
+
152
+ function namespace_constraints (sys:: OptimizationSystem )
153
+ namespace_constraint .(get_constraints (sys), Ref (sys))
154
+ end
155
+
156
+ function constraints (sys:: OptimizationSystem )
157
+ cs = get_constraints (sys)
158
+ systems = get_systems (sys)
159
+ isempty (systems) ? cs : [cs; reduce (vcat, namespace_constraints .(systems))]
129
160
end
130
- namespace_expr (sys:: OptimizationSystem ) = namespace_expr (get_op (sys), sys)
131
161
132
162
hessian_sparsity (sys:: OptimizationSystem ) = hessian_sparsity (get_op (sys), states (sys))
133
163
@@ -168,6 +198,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
168
198
kwargs... ) where {iip}
169
199
dvs = states (sys)
170
200
ps = parameters (sys)
201
+ cstr = constraints (sys)
171
202
172
203
defs = defaults (sys)
173
204
defs = mergedefaults (defs, parammap, ps)
@@ -216,8 +247,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
216
247
hess_prototype = nothing
217
248
end
218
249
219
- if length (sys . constraints ) > 0
220
- @named cons_sys = NonlinearSystem (sys . constraints , dvs, ps)
250
+ if length (cstr ) > 0
251
+ @named cons_sys = NonlinearSystem (cstr , dvs, ps)
221
252
cons = generate_function (cons_sys, checkbounds = checkbounds,
222
253
linenumbers = linenumbers,
223
254
expression = Val{false })[2 ]
@@ -237,6 +268,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
237
268
238
269
_f = DiffEqBase. OptimizationFunction {iip} (f,
239
270
sys = sys,
271
+ syms = nameof .(states (sys)),
240
272
SciMLBase. NoAD ();
241
273
grad = _grad,
242
274
hess = _hess,
@@ -251,6 +283,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
251
283
else
252
284
_f = DiffEqBase. OptimizationFunction {iip} (f,
253
285
sys = sys,
286
+ syms = nameof .(states (sys)),
254
287
SciMLBase. NoAD ();
255
288
grad = _grad,
256
289
hess = _hess,
0 commit comments