Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 91d3d93

Browse files
fix symbolic analysis dispatches
1 parent e5e8543 commit 91d3d93

File tree

4 files changed

+67
-65
lines changed

4 files changed

+67
-65
lines changed
Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,82 @@
11
module OptimizationSymbolicAnalysisExt
22

3-
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics
3+
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics, OptimizationBase.ArrayInterface
44
using SymbolicAnalysis: AnalysisResult
5-
import Symbolics: variable, Equation, Inequality, unwrap, @variables
5+
import SymbolicAnalysis.Symbolics: variable, Equation, Inequality, unwrap, @variables
66

77
function OptimizationBase.symify_cache(
88
f::OptimizationFunction{iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP,
99
CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV},
10-
prob) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
10+
prob, manifold) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
1111
EX <: Nothing, CEX <: Nothing, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}
12-
try
13-
vars = if prob.u0 isa Matrix
14-
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
15-
else
16-
ArrayInterface.restructure(
17-
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
18-
end
19-
params = if prob.p isa SciMLBase.NullParameters
20-
[]
21-
elseif prob.p isa MTK.MTKParameters
22-
[variable(, i) for i in eachindex(vcat(p...))]
23-
else
24-
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
25-
end
12+
obj_expr = f.expr
13+
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))
2614

27-
if prob.u0 isa Matrix
28-
vars = vars[1]
29-
end
15+
16+
if obj_expr === nothing || cons_expr === nothing
17+
try
18+
vars = if prob.u0 isa Matrix
19+
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
20+
else
21+
ArrayInterface.restructure(
22+
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
23+
end
24+
params = if prob.p isa SciMLBase.NullParameters
25+
[]
26+
elseif prob.p isa MTK.MTKParameters
27+
[variable(, i) for i in eachindex(vcat(p...))]
28+
else
29+
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
30+
end
3031

31-
obj_expr = f.f(vars, params)
32+
if prob.u0 isa Matrix
33+
vars = vars[1]
34+
end
3235

33-
if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
34-
lhs = Array{Symbolics.Num}(undef, num_cons)
35-
f.cons(lhs, vars)
36-
cons = Union{Equation, Inequality}[]
36+
if obj_expr === nothing
37+
obj_expr = f.f(vars, params)
38+
end
3739

38-
if !isnothing(prob.lcons)
39-
for i in 1:num_cons
40-
if !isinf(prob.lcons[i])
41-
if prob.lcons[i] != prob.ucons[i]
42-
push!(cons, prob.lcons[i] lhs[i])
43-
else
44-
push!(cons, lhs[i] ~ prob.ucons[i])
40+
if cons_expr === nothing && SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
41+
lhs = Array{Symbolics.Num}(undef, num_cons)
42+
f.cons(lhs, vars)
43+
cons = Union{Equation, Inequality}[]
44+
45+
if !isnothing(prob.lcons)
46+
for i in 1:num_cons
47+
if !isinf(prob.lcons[i])
48+
if prob.lcons[i] != prob.ucons[i]
49+
push!(cons, prob.lcons[i] lhs[i])
50+
else
51+
push!(cons, lhs[i] ~ prob.ucons[i])
52+
end
4553
end
4654
end
4755
end
48-
end
4956

50-
if !isnothing(prob.ucons)
51-
for i in 1:num_cons
52-
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
53-
push!(cons, lhs[i] prob.ucons[i])
57+
if !isnothing(prob.ucons)
58+
for i in 1:num_cons
59+
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
60+
push!(cons, lhs[i] prob.ucons[i])
61+
end
5462
end
5563
end
64+
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
65+
(isnothing(prob.ucons) || all(isinf, prob.ucons))
66+
throw(ArgumentError("Constraints passed have no proper bounds defined.
67+
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
68+
or pass the lower and upper bounds for inequality constraints."))
69+
end
70+
cons_expr = lhs
71+
elseif cons_expr === nothing && !isnothing(prob.f.cons)
72+
cons_expr = f.cons(vars, params)
5673
end
57-
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
58-
(isnothing(prob.ucons) || all(isinf, prob.ucons))
59-
throw(ArgumentError("Constraints passed have no proper bounds defined.
60-
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
61-
or pass the lower and upper bounds for inequality constraints."))
62-
end
63-
cons_expr = lhs
64-
elseif !isnothing(prob.f.cons)
65-
cons_expr = f.cons(vars, params)
66-
else
67-
cons_expr = nothing
74+
catch err
75+
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
76+
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
6877
end
69-
catch err
70-
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
71-
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
7278
end
73-
return obj_expr, cons_expr
74-
end
7579

76-
function analysis(obj_expr, cons_expr)
7780
if obj_expr !== nothing
7881
obj_expr = obj_expr |> Symbolics.unwrap
7982
if manifold === nothing
@@ -85,6 +88,8 @@ function analysis(obj_expr, cons_expr)
8588
if obj_res.gcurvature !== nothing
8689
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
8790
end
91+
else
92+
obj_res = nothing
8893
end
8994

9095
if cons_expr !== nothing
@@ -101,9 +106,12 @@ function analysis(obj_expr, cons_expr)
101106
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
102107
end
103108
end
109+
else
110+
cons_res = nothing
104111
end
105112

106113
return obj_res, cons_res
107114
end
108115

116+
109117
end

src/cache.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
5050
cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
5151

5252
if structural_analysis
53-
obj_expr, cons_expr = symify_cache(f, prob)
54-
try
55-
obj_res, cons_res = analysis(obj_expr, cons_expr)
56-
catch err
57-
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
58-
end
53+
obj_res, cons_res = symify_cache(f, prob, manifold)
5954
else
6055
obj_res = nothing
6156
cons_res = nothing

src/symify.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
function symify_cache(f::OptimizationFunction, prob)
2-
obj_expr = f.expr
3-
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))
4-
5-
return obj_expr, cons_expr
1+
function symify_cache(f::OptimizationFunction, prob, manifold)
2+
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
63
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ using Test
33

44
@testset "OptimizationBase.jl" begin
55
include("adtests.jl")
6+
include("cvxtest.jl")
7+
include("matrixvalued.jl")
68
end

0 commit comments

Comments
 (0)