Skip to content

Commit 765d4c0

Browse files
Merge pull request #3844 from AayushSabharwal/as/fix-optfn
fix: fix `OptimizationFunction` generation of exprs
2 parents 647c9f9 + a4b96d3 commit 765d4c0

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

src/problems/optimizationproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ function SciMLBase.OptimizationFunction{iip}(sys::System;
5858
else
5959
_cons_h = cons_hess_prototype = nothing
6060
end
61-
cons_expr = cstr
61+
cons_expr = Code.toexpr.(expand.([eq.lhs for eq in Symbolics.canonical_form.(cstr)]))
6262
end
6363

64-
obj_expr = cost(sys)
64+
obj_expr = Code.toexpr(expand(cost(sys)))
6565

6666
observedfun = ObservedFunctionCache(
6767
sys; expression, eval_expression, eval_module, checkbounds, cse)

test/initializationsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ end
12931293
@test SciMLBase.successful_retcode(solve(prob))
12941294

12951295
seta = setsym_oop(prob, [a])
1296-
(newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 1))
1296+
(newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 0))
12971297
newprob = remake(prob, u0 = newu0, p = newp)
12981298

12991299
@test SciMLBase.successful_retcode(solve(newprob))

test/optimizationsystem.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ end
6464
sys = complete(sys)
6565
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0, a => 1.0, b => 1.0],
6666
grad = true, hess = true, cons_j = true, cons_h = true)
67+
@test prob.f.cons_expr isa Vector{Expr}
68+
@test prob.f.expr isa Expr
6769
@test prob.f.sys === sys
6870
sol = solve(prob, IPNewton())
6971
@test sol.objective < 1.0
@@ -98,10 +100,10 @@ end
98100

99101
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0, z => 0.0, a => 1.0, b => 1.0],
100102
grad = false, hess = false, cons_j = false, cons_h = false)
101-
@test_broken sol = solve(prob, AmplNLWriter.Optimizer(Ipopt_jll.amplexe))
102-
@test_skip sol.objective < 1.0
103-
@test_skip sol.u[0.808, -0.064] atol=1e-3
104-
@test_skip sol[x]^2 + sol[y]^2 1.0
103+
sol = solve(prob, AmplNLWriter.Optimizer(Ipopt_jll.amplexe))
104+
@test sol.objective < 1.0
105+
@test_broken sol.u[0.808, -0.064] atol=1e-3
106+
@test_broken sol[x]^2 + sol[y]^2 1.0
105107
end
106108

107109
@testset "rosenbrock" begin
@@ -289,9 +291,8 @@ end
289291
sys = complete(sys)
290292

291293
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0, a => 1.0, b => 100.0])
292-
@test prob.f.expr isa Symbolics.Symbolic
293-
@test all(prob.f.cons_expr[i].lhs isa Symbolics.Symbolic
294-
for i in 1:length(prob.f.cons_expr))
294+
@test prob.f.expr isa Expr
295+
@test all(x -> x isa Expr, prob.f.cons_expr)
295296
end
296297

297298
@testset "Derivatives, iip and oop" begin

0 commit comments

Comments
 (0)