Skip to content

Commit cff02c9

Browse files
refactor: use build_function_wrapper in OptimizationProblem codegen
1 parent f28df16 commit cff02c9

File tree

2 files changed

+19
-45
lines changed

2 files changed

+19
-45
lines changed

src/systems/optimization/constraints_system.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,10 @@ end
172172

173173
function generate_jacobian(
174174
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(sys);
175-
sparse = false, simplify = false, wrap_code = identity, kwargs...)
175+
sparse = false, simplify = false, kwargs...)
176176
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
177177
p = reorder_parameters(sys, ps)
178-
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) .∘
179-
wrap_parameter_dependencies(sys, false)
180-
return build_function(jac, vs, p...; wrap_code, kwargs...)
178+
return build_function_wrapper(sys, jac, vs, p...; kwargs...)
181179
end
182180

183181
function calculate_hessian(sys::ConstraintsSystem; sparse = false, simplify = false)
@@ -193,25 +191,18 @@ end
193191

194192
function generate_hessian(
195193
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(sys);
196-
sparse = false, simplify = false, wrap_code = identity, kwargs...)
194+
sparse = false, simplify = false, kwargs...)
197195
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
198196
p = reorder_parameters(sys, ps)
199-
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
200-
wrap_parameter_dependencies(sys, false)
201-
return build_function(hess, vs, p...; wrap_code, kwargs...)
197+
return build_function_wrapper(sys, hess, vs, p...; kwargs...)
202198
end
203199

204200
function generate_function(sys::ConstraintsSystem, dvs = unknowns(sys),
205201
ps = parameters(sys);
206-
wrap_code = identity,
207202
kwargs...)
208203
lhss = generate_canonical_form_lhss(sys)
209-
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
210204
p = reorder_parameters(sys, value.(ps))
211-
wrap_code = wrap_code .∘ wrap_array_vars(sys, lhss; dvs, ps) .∘
212-
wrap_parameter_dependencies(sys, false)
213-
func = build_function(lhss, value.(dvs), p...; postprocess_fbody = pre,
214-
states = sol_states, wrap_code, kwargs...)
205+
func = build_function_wrapper(sys, lhss, value.(dvs), p...; kwargs...)
215206

216207
cstr = constraints(sys)
217208
lcons = fill(-Inf, length(cstr))

src/systems/optimization/optimizationsystem.jl

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,10 @@ function calculate_gradient(sys::OptimizationSystem)
194194
end
195195

196196
function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys),
197-
ps = parameters(sys);
198-
wrap_code = identity,
199-
kwargs...)
197+
ps = parameters(sys); kwargs...)
200198
grad = calculate_gradient(sys)
201-
pre = get_preprocess_constants(grad)
202199
p = reorder_parameters(sys, ps)
203-
wrap_code = wrap_code .∘ wrap_array_vars(sys, grad; dvs = vs, ps) .∘
204-
wrap_parameter_dependencies(sys, !(grad isa AbstractArray))
205-
return build_function(grad, vs, p...; postprocess_fbody = pre, wrap_code,
206-
kwargs...)
200+
return build_function_wrapper(sys, grad, vs, p...; kwargs...)
207201
end
208202

209203
function calculate_hessian(sys::OptimizationSystem)
@@ -212,34 +206,22 @@ end
212206

213207
function generate_hessian(
214208
sys::OptimizationSystem, vs = unknowns(sys), ps = parameters(sys);
215-
sparse = false, wrap_code = identity, kwargs...)
209+
sparse = false, kwargs...)
216210
if sparse
217211
hess = sparsehessian(objective(sys), unknowns(sys))
218212
else
219213
hess = calculate_hessian(sys)
220214
end
221-
pre = get_preprocess_constants(hess)
222215
p = reorder_parameters(sys, ps)
223-
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
224-
wrap_parameter_dependencies(sys, false)
225-
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code,
226-
kwargs...)
216+
return build_function_wrapper(sys, hess, vs, p...; kwargs...)
227217
end
228218

229219
function generate_function(sys::OptimizationSystem, vs = unknowns(sys),
230220
ps = parameters(sys);
231-
wrap_code = identity,
232-
kwargs...)
233-
eqs = subs_constants(objective(sys))
234-
p = if has_index_cache(sys)
235-
reorder_parameters(get_index_cache(sys), ps)
236-
else
237-
(ps,)
238-
end
239-
wrap_code = wrap_code .∘ wrap_array_vars(sys, eqs; dvs = vs, ps) .∘
240-
wrap_parameter_dependencies(sys, !(eqs isa AbstractArray))
241-
return build_function(eqs, vs, p...; wrap_code,
242221
kwargs...)
222+
eqs = objective(sys)
223+
p = reorder_parameters(sys, ps)
224+
return build_function_wrapper(sys, eqs, vs, p...; kwargs...)
243225
end
244226

245227
function namespace_objective(sys::AbstractSystem)
@@ -368,7 +350,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
368350
f = let _f = eval_or_rgf(
369351
generate_function(
370352
sys, checkbounds = checkbounds, linenumbers = linenumbers,
371-
expression = Val{true});
353+
expression = Val{true}, wrap_mtkparameters = false);
372354
eval_expression,
373355
eval_module)
374356
__f(u, p) = _f(u, p)
@@ -382,7 +364,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
382364
generate_gradient(
383365
sys, checkbounds = checkbounds,
384366
linenumbers = linenumbers,
385-
parallel = parallel, expression = Val{true});
367+
parallel = parallel, expression = Val{true},
368+
wrap_mtkparameters = false);
386369
eval_expression,
387370
eval_module)
388371
_grad(u, p) = grad_oop(u, p)
@@ -401,7 +384,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
401384
sys, checkbounds = checkbounds,
402385
linenumbers = linenumbers,
403386
sparse = sparse, parallel = parallel,
404-
expression = Val{true});
387+
expression = Val{true}, wrap_mtkparameters = false);
405388
eval_expression,
406389
eval_module)
407390
_hess(u, p) = hess_oop(u, p)
@@ -427,7 +410,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
427410
cons_sys = complete(cons_sys)
428411
cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds,
429412
linenumbers = linenumbers,
430-
expression = Val{true})
413+
expression = Val{true}; wrap_mtkparameters = false)
431414
cons = let (cons_oop, cons_iip) = eval_or_rgf.(cons; eval_expression, eval_module)
432415
_cons(u, p) = cons_oop(u, p)
433416
_cons(resid, u, p) = cons_iip(resid, u, p)
@@ -440,7 +423,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
440423
checkbounds = checkbounds,
441424
linenumbers = linenumbers,
442425
parallel = parallel, expression = Val{true},
443-
sparse = cons_sparse);
426+
sparse = cons_sparse, wrap_mtkparameters = false);
444427
eval_expression,
445428
eval_module)
446429
_cons_j(u, p) = cons_jac_oop(u, p)
@@ -458,7 +441,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
458441
cons_sys, checkbounds = checkbounds,
459442
linenumbers = linenumbers,
460443
sparse = cons_sparse, parallel = parallel,
461-
expression = Val{true});
444+
expression = Val{true}, wrap_mtkparameters = false);
462445
eval_expression,
463446
eval_module)
464447
_cons_h(u, p) = cons_hess_oop(u, p)

0 commit comments

Comments
 (0)