|
1 | 1 | import Symbolics: ≲, ~ |
2 | 2 |
|
| 3 | +isa_dataiterator(data) = false |
| 4 | + |
3 | 5 | struct AnalysisResults |
4 | 6 | objective::Union{Nothing, AnalysisResult} |
5 | 7 | constraints::Union{Nothing, Vector{AnalysisResult}} |
@@ -32,122 +34,37 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt; |
32 | 34 | structural_analysis = false, |
33 | 35 | manifold = nothing, |
34 | 36 | kwargs...) |
35 | | - reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p) |
| 37 | + |
| 38 | + if isa_dataiterator(prob.p) |
| 39 | + reinit_cache = OptimizationBase.ReInitCache(prob.u0, iterate(prob.p)[1]) |
| 40 | + reinit_cache_passedon = OptimizationBase.ReInitCache(prob.u0, prob.p) |
| 41 | + else |
| 42 | + reinit_cache = OptimizationBase.ReInitCache(prob.u0, iterate(prob.p)[1]) |
| 43 | + reinit_cache_passedon = reinit_cache |
| 44 | + end |
| 45 | + |
36 | 46 | num_cons = prob.ucons === nothing ? 0 : length(prob.ucons) |
| 47 | + |
37 | 48 | f = OptimizationBase.instantiate_function( |
38 | 49 | prob.f, reinit_cache, prob.f.adtype, num_cons; |
39 | 50 | g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), |
40 | 51 | hv = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt), |
41 | 52 | fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt), |
42 | 53 | cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt)) |
43 | 54 |
|
44 | | - if (f.sys === nothing || |
45 | | - f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) && |
46 | | - structural_analysis |
| 55 | + if structural_analysis |
| 56 | + obj_expr, cons_expr = symify_cache(f, prob) |
47 | 57 | try |
48 | | - vars = if prob.u0 isa Matrix |
49 | | - @variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)] |
50 | | - else |
51 | | - ArrayInterface.restructure( |
52 | | - prob.u0, [variable(:x, i) for i in eachindex(prob.u0)]) |
53 | | - end |
54 | | - params = if prob.p isa SciMLBase.NullParameters |
55 | | - [] |
56 | | - elseif prob.p isa MTK.MTKParameters |
57 | | - [variable(:α, i) for i in eachindex(vcat(p...))] |
58 | | - else |
59 | | - ArrayInterface.restructure(p, [variable(:α, i) for i in eachindex(p)]) |
60 | | - end |
61 | | - |
62 | | - if prob.u0 isa Matrix |
63 | | - vars = vars[1] |
64 | | - end |
65 | | - |
66 | | - obj_expr = f.f(vars, params) |
67 | | - |
68 | | - if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons) |
69 | | - lhs = Array{Symbolics.Num}(undef, num_cons) |
70 | | - f.cons(lhs, vars) |
71 | | - cons = Union{Equation, Inequality}[] |
72 | | - |
73 | | - if !isnothing(prob.lcons) |
74 | | - for i in 1:num_cons |
75 | | - if !isinf(prob.lcons[i]) |
76 | | - if prob.lcons[i] != prob.ucons[i] |
77 | | - push!(cons, prob.lcons[i] ≲ lhs[i]) |
78 | | - else |
79 | | - push!(cons, lhs[i] ~ prob.ucons[i]) |
80 | | - end |
81 | | - end |
82 | | - end |
83 | | - end |
84 | | - |
85 | | - if !isnothing(prob.ucons) |
86 | | - for i in 1:num_cons |
87 | | - if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i] |
88 | | - push!(cons, lhs[i] ≲ prob.ucons[i]) |
89 | | - end |
90 | | - end |
91 | | - end |
92 | | - if (isnothing(prob.lcons) || all(isinf, prob.lcons)) && |
93 | | - (isnothing(prob.ucons) || all(isinf, prob.ucons)) |
94 | | - throw(ArgumentError("Constraints passed have no proper bounds defined. |
95 | | - Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints |
96 | | - or pass the lower and upper bounds for inequality constraints.")) |
97 | | - end |
98 | | - cons_expr = lhs |
99 | | - elseif !isnothing(prob.f.cons) |
100 | | - cons_expr = f.cons(vars, params) |
101 | | - else |
102 | | - cons_expr = nothing |
103 | | - end |
| 58 | + obj_res, cons_res = analysis(obj_expr, cons_expr) |
104 | 59 | catch err |
105 | | - throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err. |
106 | | - Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions.")) |
107 | | - end |
108 | | - else |
109 | | - sys = f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing} ? |
110 | | - nothing : f.sys |
111 | | - obj_expr = f.expr |
112 | | - cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs)) |
113 | | - end |
114 | | - |
115 | | - if obj_expr !== nothing && structural_analysis |
116 | | - obj_expr = obj_expr |> Symbolics.unwrap |
117 | | - if manifold === nothing |
118 | | - obj_res = analyze(obj_expr) |
119 | | - else |
120 | | - obj_res = analyze(obj_expr, manifold) |
121 | | - end |
122 | | - |
123 | | - @info "Objective Euclidean curvature: $(obj_res.curvature)" |
124 | | - |
125 | | - if obj_res.gcurvature !== nothing |
126 | | - @info "Objective Geodesic curvature: $(obj_res.gcurvature)" |
| 60 | + throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.") |
127 | 61 | end |
128 | 62 | else |
129 | 63 | obj_res = nothing |
130 | | - end |
131 | | - |
132 | | - if cons_expr !== nothing && structural_analysis |
133 | | - cons_expr = cons_expr .|> Symbolics.unwrap |
134 | | - if manifold === nothing |
135 | | - cons_res = analyze.(cons_expr) |
136 | | - else |
137 | | - cons_res = analyze.(cons_expr, Ref(manifold)) |
138 | | - end |
139 | | - for i in 1:num_cons |
140 | | - @info "Constraints Euclidean curvature: $(cons_res[i].curvature)" |
141 | | - |
142 | | - if cons_res[i].gcurvature !== nothing |
143 | | - @info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)" |
144 | | - end |
145 | | - end |
146 | | - else |
147 | 64 | cons_res = nothing |
148 | 65 | end |
149 | 66 |
|
150 | | - return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons, |
| 67 | + return OptimizationCache(f, reinit_cache_passedon, prob.lb, prob.ub, prob.lcons, |
151 | 68 | prob.ucons, prob.sense, |
152 | 69 | opt, progress, callback, manifold, AnalysisResults(obj_res, cons_res), |
153 | 70 | merge((; maxiters, maxtime, abstol, reltol), |
|
0 commit comments