Skip to content

Commit f8ba227

Browse files
Merge pull request #304 from SciML/mtkexpr
Handle prototype typing and expression clean up in `NoAD`
2 parents ecb5541 + 385deb3 commit f8ba227

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

src/function/function.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
2+
function symbolify(e::Expr)
3+
if !(e.args[1] isa Symbol)
4+
e.args[1] = Symbol(e.args[1])
5+
end
6+
symbolify.(e.args)
7+
return e
8+
end
9+
10+
function symbolify(e)
11+
return e
12+
end
13+
14+
function rep_pars_vals!(e::Expr, p)
15+
rep_pars_vals!.(e.args, Ref(p))
16+
replace!(e.args, p...)
17+
end
18+
19+
function rep_pars_vals!(e, p) end
20+
121
"""
222
instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)::OptimizationFunction
323
@@ -30,8 +50,14 @@ function instantiate_function(f, x, ::AbstractADType, p, num_cons = 0)
3050
cons = f.cons === nothing ? nothing : (x)->f.cons(x,p)
3151
cons_j = f.cons_j === nothing ? nothing : (res,x)->f.cons_j(res,x,p)
3252
cons_h = f.cons_h === nothing ? nothing : (res,x)->f.cons_h(res,x,p)
53+
hess_prototype = f.hess_prototype === nothing ? nothing : convert.(eltype(x), f.hess_prototype)
54+
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : convert.(eltype(x), f.cons_jac_prototype)
55+
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : convert.(eltype(x), f.cons_hess_prototype)
56+
expr = symbolify(f.expr)
57+
cons_expr = symbolify.(f.cons_expr)
3358

3459
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad=grad, hess=hess, hv=hv,
3560
cons=cons, cons_j=cons_j, cons_h=cons_h,
36-
hess_prototype=f.hess_prototype, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=f.cons_hess_prototype)
61+
hess_prototype=hess_prototype, cons_jac_prototype=cons_jac_prototype,
62+
cons_hess_prototype=cons_hess_prototype, expr=expr, cons_expr=cons_expr)
3763
end

src/function/mtk.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,6 @@ end
55

66
AutoModelingToolkit() = AutoModelingToolkit(false, false)
77

8-
function symbolify(e::Expr)
9-
if !(e.args[1] isa Symbol)
10-
e.args[1] = Symbol(e.args[1])
11-
end
12-
symbolify.(e.args)
13-
return e
14-
end
15-
16-
function symbolify(e)
17-
return e
18-
end
19-
20-
function rep_pars_vals!(e::Expr, p)
21-
rep_pars_vals!.(e.args, Ref(p))
22-
replace!(e.args, p...)
23-
end
24-
25-
function rep_pars_vals!(e, p) end
26-
278
function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons=0)
289
p = isnothing(p) ? SciMLBase.NullParameters() : p
2910
sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p))

0 commit comments

Comments
 (0)