Skip to content

Commit 6cd5bb1

Browse files
fix
1 parent 0f7388e commit 6cd5bb1

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

src/systems/optimization/constraints_system.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ end
178178
function generate_function(sys::ConstraintsSystem, dvs = states(sys), ps = parameters(sys);
179179
kwargs...)
180180
lhss = generate_canonical_form_lhss(sys)
181-
func = build_function(lhss, value.(dvs), value.(ps))
181+
pre, sol_states = get_substitutions_and_solved_states(sys)
182+
183+
func = build_function(lhss, value.(dvs), value.(ps); postprocess_fbody = pre,
184+
states = sol_states, kwargs...)
182185

183186
cstr = constraints(sys)
184187
lcons = fill(-Inf, length(cstr))
@@ -206,3 +209,14 @@ g(x) <= 0
206209
function generate_canonical_form_lhss(sys)
207210
lhss = subs_constants([Symbolics.canonical_form(eq).lhs for eq in constraints(sys)])
208211
end
212+
213+
function get_cmap(sys::ConstraintsSystem)
214+
#Inject substitutions for constants => values
215+
cs = collect_constants([get_constraints(sys); get_observed(sys)]) #ctrls? what else?
216+
if !empty_substitutions(sys)
217+
cs = [cs; collect_constants(get_substitutions(sys).subs)]
218+
end
219+
# Swap constants for their values
220+
cmap = map(x -> x ~ getdefault(x), cs)
221+
return cmap, cs
222+
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
245245
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false, use_union)
246246
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false, use_union)
247247

248-
if !isnothing(lb) && all(lb .== -Inf)
248+
if !isnothing(lb) && all(lb .== -Inf) && !isnothing(ub) && all(ub .== Inf)
249249
lb = nothing
250-
end
251-
if !isnothing(ub) && all(ub .== Inf)
252250
ub = nothing
253251
end
254252

@@ -301,16 +299,18 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
301299
cons_expr = toexpr.(subs_constants(constraints(cons_sys)))
302300
rep_pars_vals!.(cons_expr, Ref(pairs_arr))
303301

304-
if isnothing(lcons) && isnothing(ucons) # use the symbolically specified bounds
302+
if !haskey(kwargs, :lcons) && !haskey(kwargs, :ucons) # use the symbolically specified bounds
305303
lcons = lcons_
306304
ucons = ucons_
307305
else # use the user supplied constraints bounds
308-
xor(isnothing(lcons), isnothing(lcons)) &&
309-
throw(ArgumentError("Expected both `lcons` and `lcons` to be supplied"))
310-
!isnothing(lcons) && length(lcons) != length(cstr) &&
306+
haskey(kwargs, :lcons) && haskey(kwargs, :ucons) &&
307+
throw(ArgumentError("Expected both `ucons` and `lcons` to be supplied"))
308+
haskey(kwargs, :lcons) && length(kwargs[:lcons]) != length(cstr) &&
311309
throw(ArgumentError("Expected `lcons` to be of the same length as the vector of constraints"))
312-
!isnothing(ucons) && length(ucons) != length(cstr) &&
310+
haskey(kwargs, :ucons) && length(kwargs[:ucons]) != length(cstr) &&
313311
throw(ArgumentError("Expected `ucons` to be of the same length as the vector of constraints"))
312+
lcons = haskey(kwargs, :lcons)
313+
ucons = haskey(kwargs, :ucons)
314314
end
315315

316316
if sparse
@@ -417,10 +417,8 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
417417
lb = varmap_to_vars(dvs .=> lb, dvs; defaults = defs, tofloat = false, use_union)
418418
ub = varmap_to_vars(dvs .=> ub, dvs; defaults = defs, tofloat = false, use_union)
419419

420-
if !isnothing(lb) && all(lb .== -Inf)
420+
if !isnothing(lb) && all(lb .== -Inf) && !isnothing(ub) && all(ub .== Inf)
421421
lb = nothing
422-
end
423-
if !isnothing(ub) && all(ub .== Inf)
424422
ub = nothing
425423
end
426424

@@ -468,16 +466,18 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
468466
cons_expr = toexpr.(subs_constants(constraints(cons_sys)))
469467
rep_pars_vals!.(cons_expr, Ref(pairs_arr))
470468

471-
if isnothing(lcons) && isnothing(ucons) # use the symbolically specified bounds
469+
if !haskey(kwargs, :lcons) && !haskey(kwargs, :ucons) # use the symbolically specified bounds
472470
lcons = lcons_
473471
ucons = ucons_
474472
else # use the user supplied constraints bounds
475-
xor(isnothing(lcons), isnothing(lcons)) &&
476-
throw(ArgumentError("Expected both `lcons` and `lcons` to be supplied"))
477-
!isnothing(lcons) && length(lcons) != length(cstr) &&
473+
haskey(kwargs, :lcons) && haskey(kwargs, :ucons) &&
474+
throw(ArgumentError("Expected both `ucons` and `lcons` to be supplied"))
475+
haskey(kwargs, :lcons) && length(kwargs[:lcons]) != length(cstr) &&
478476
throw(ArgumentError("Expected `lcons` to be of the same length as the vector of constraints"))
479-
!isnothing(ucons) && length(ucons) != length(cstr) &&
477+
haskey(kwargs, :ucons) && length(kwargs[:ucons]) != length(cstr) &&
480478
throw(ArgumentError("Expected `ucons` to be of the same length as the vector of constraints"))
479+
lcons = haskey(kwargs, :lcons)
480+
ucons = haskey(kwargs, :ucons)
481481
end
482482

483483
if sparse

src/utils.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,15 +597,19 @@ function empty_substitutions(sys)
597597
isnothing(subs) || isempty(subs.deps)
598598
end
599599

600-
function get_substitutions_and_solved_states(sys; no_postprocess = false)
600+
function get_cmap(sys)
601601
#Inject substitutions for constants => values
602602
cs = collect_constants([get_eqs(sys); get_observed(sys)]) #ctrls? what else?
603603
if !empty_substitutions(sys)
604604
cs = [cs; collect_constants(get_substitutions(sys).subs)]
605605
end
606606
# Swap constants for their values
607607
cmap = map(x -> x ~ getdefault(x), cs)
608+
return cmap, cs
609+
end
608610

611+
function get_substitutions_and_solved_states(sys; no_postprocess = false)
612+
cmap, cs = get_cmap(sys)
609613
if empty_substitutions(sys) && isempty(cs)
610614
sol_states = Code.LazyState()
611615
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)

0 commit comments

Comments
 (0)