Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 8084faa

Browse files
Move iterator checking here and make symbolics stuff extension
1 parent 3038abf commit 8084faa

File tree

7 files changed

+149
-105
lines changed

7 files changed

+149
-105
lines changed

Project.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,27 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1717
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1818
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1919
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
20-
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
21-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
22-
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2320

2421
[weakdeps]
2522
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2623
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2724
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
25+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
26+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2827
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
2928
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
29+
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
3030
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3131

3232
[extensions]
3333
OptimizationEnzymeExt = "Enzyme"
3434
OptimizationFiniteDiffExt = "FiniteDiff"
3535
OptimizationForwardDiffExt = "ForwardDiff"
36+
OptimizationMLDataDevicesExt = "MLDataDevices"
37+
OptimizationMLUtilsExt = "MLUtils"
3638
OptimizationMTKExt = "ModelingToolkit"
3739
OptimizationReverseDiffExt = "ReverseDiff"
40+
OptimizationSymbolicAnalysisExt = "SymbolicAnalysis"
3841
OptimizationZygoteExt = "Zygote"
3942

4043
[compat]
@@ -56,8 +59,6 @@ SciMLBase = "2"
5659
SparseConnectivityTracer = "0.6"
5760
SparseMatrixColorings = "0.4"
5861
SymbolicAnalysis = "0.3"
59-
SymbolicIndexingInterface = "0.3"
60-
Symbolics = "5.12, 6"
6162
Zygote = "0.6.67"
6263
julia = "1.10"
6364

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OptimizationMLDataDevicesExt
2+
3+
using MLDataDevices
4+
using OptimizationBase
5+
6+
OptimizationBase.isa_dataiterator(::DeviceIterator) = true
7+
8+
end

ext/OptimizationMLUtilsExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OptimizationMLUtilsExt
2+
3+
using MLUtils
4+
using OptimizationBase
5+
6+
OptimizationBase.isa_dataiterator(::MLUtils.DataLoader) = true
7+
8+
end
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
module OptimizationSymbolicAnalysisExt
2+
3+
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics
4+
5+
function OptimizationBase.symify_cache(f::OptimizationFunction{iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}, prob) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O, EX <: Nothing, CEX <: Nothing, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}
6+
try
7+
vars = if prob.u0 isa Matrix
8+
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
9+
else
10+
ArrayInterface.restructure(
11+
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
12+
end
13+
params = if prob.p isa SciMLBase.NullParameters
14+
[]
15+
elseif prob.p isa MTK.MTKParameters
16+
[variable(, i) for i in eachindex(vcat(p...))]
17+
else
18+
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
19+
end
20+
21+
if prob.u0 isa Matrix
22+
vars = vars[1]
23+
end
24+
25+
obj_expr = f.f(vars, params)
26+
27+
if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
28+
lhs = Array{Symbolics.Num}(undef, num_cons)
29+
f.cons(lhs, vars)
30+
cons = Union{Equation, Inequality}[]
31+
32+
if !isnothing(prob.lcons)
33+
for i in 1:num_cons
34+
if !isinf(prob.lcons[i])
35+
if prob.lcons[i] != prob.ucons[i]
36+
push!(cons, prob.lcons[i] lhs[i])
37+
else
38+
push!(cons, lhs[i] ~ prob.ucons[i])
39+
end
40+
end
41+
end
42+
end
43+
44+
if !isnothing(prob.ucons)
45+
for i in 1:num_cons
46+
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
47+
push!(cons, lhs[i] prob.ucons[i])
48+
end
49+
end
50+
end
51+
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
52+
(isnothing(prob.ucons) || all(isinf, prob.ucons))
53+
throw(ArgumentError("Constraints passed have no proper bounds defined.
54+
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
55+
or pass the lower and upper bounds for inequality constraints."))
56+
end
57+
cons_expr = lhs
58+
elseif !isnothing(prob.f.cons)
59+
cons_expr = f.cons(vars, params)
60+
else
61+
cons_expr = nothing
62+
end
63+
catch err
64+
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
65+
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
66+
end
67+
return obj_expr, cons_expr
68+
end
69+
70+
function analysis(obj_expr, cons_expr)
71+
if obj_expr !== nothing
72+
obj_expr = obj_expr |> Symbolics.unwrap
73+
if manifold === nothing
74+
obj_res = analyze(obj_expr)
75+
else
76+
obj_res = analyze(obj_expr, manifold)
77+
end
78+
@info "Objective Euclidean curvature: $(obj_res.curvature)"
79+
if obj_res.gcurvature !== nothing
80+
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
81+
end
82+
end
83+
84+
if cons_expr !== nothing
85+
cons_expr = cons_expr .|> Symbolics.unwrap
86+
if manifold === nothing
87+
cons_res = analyze.(cons_expr)
88+
else
89+
cons_res = analyze.(cons_expr, Ref(manifold))
90+
end
91+
for i in 1:num_cons
92+
@info "Constraints Euclidean curvature: $(cons_res[i].curvature)"
93+
94+
if cons_res[i].gcurvature !== nothing
95+
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
96+
end
97+
end
98+
end
99+
100+
return obj_res, cons_res
101+
end
102+
103+
end

src/OptimizationBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Base.iterate(::NullData, i = 1) = nothing
3131
Base.length(::NullData) = 0
3232

3333
include("adtypes.jl")
34+
include("symify.jl")
3435
include("cache.jl")
3536
include("OptimizationDIExt.jl")
3637
include("OptimizationDISparseExt.jl")

src/cache.jl

Lines changed: 17 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import Symbolics: , ~
22

3+
isa_dataiterator(data) = false
4+
35
struct AnalysisResults
46
objective::Union{Nothing, AnalysisResult}
57
constraints::Union{Nothing, Vector{AnalysisResult}}
@@ -32,122 +34,37 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
3234
structural_analysis = false,
3335
manifold = nothing,
3436
kwargs...)
35-
reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
37+
38+
if isa_dataiterator(prob.p)
39+
reinit_cache = OptimizationBase.ReInitCache(prob.u0, iterate(prob.p)[1])
40+
reinit_cache_passedon = OptimizationBase.ReInitCache(prob.u0, prob.p)
41+
else
42+
reinit_cache = OptimizationBase.ReInitCache(prob.u0, iterate(prob.p)[1])
43+
reinit_cache_passedon = reinit_cache
44+
end
45+
3646
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
47+
3748
f = OptimizationBase.instantiate_function(
3849
prob.f, reinit_cache, prob.f.adtype, num_cons;
3950
g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),
4051
hv = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
4152
fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt),
4253
cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
4354

44-
if (f.sys === nothing ||
45-
f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) &&
46-
structural_analysis
55+
if structural_analysis
56+
obj_expr, cons_expr = symify_cache(f, prob)
4757
try
48-
vars = if prob.u0 isa Matrix
49-
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
50-
else
51-
ArrayInterface.restructure(
52-
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
53-
end
54-
params = if prob.p isa SciMLBase.NullParameters
55-
[]
56-
elseif prob.p isa MTK.MTKParameters
57-
[variable(, i) for i in eachindex(vcat(p...))]
58-
else
59-
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
60-
end
61-
62-
if prob.u0 isa Matrix
63-
vars = vars[1]
64-
end
65-
66-
obj_expr = f.f(vars, params)
67-
68-
if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
69-
lhs = Array{Symbolics.Num}(undef, num_cons)
70-
f.cons(lhs, vars)
71-
cons = Union{Equation, Inequality}[]
72-
73-
if !isnothing(prob.lcons)
74-
for i in 1:num_cons
75-
if !isinf(prob.lcons[i])
76-
if prob.lcons[i] != prob.ucons[i]
77-
push!(cons, prob.lcons[i] lhs[i])
78-
else
79-
push!(cons, lhs[i] ~ prob.ucons[i])
80-
end
81-
end
82-
end
83-
end
84-
85-
if !isnothing(prob.ucons)
86-
for i in 1:num_cons
87-
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
88-
push!(cons, lhs[i] prob.ucons[i])
89-
end
90-
end
91-
end
92-
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
93-
(isnothing(prob.ucons) || all(isinf, prob.ucons))
94-
throw(ArgumentError("Constraints passed have no proper bounds defined.
95-
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
96-
or pass the lower and upper bounds for inequality constraints."))
97-
end
98-
cons_expr = lhs
99-
elseif !isnothing(prob.f.cons)
100-
cons_expr = f.cons(vars, params)
101-
else
102-
cons_expr = nothing
103-
end
58+
obj_res, cons_res = analysis(obj_expr, cons_expr)
10459
catch err
105-
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
106-
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
107-
end
108-
else
109-
sys = f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing} ?
110-
nothing : f.sys
111-
obj_expr = f.expr
112-
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))
113-
end
114-
115-
if obj_expr !== nothing && structural_analysis
116-
obj_expr = obj_expr |> Symbolics.unwrap
117-
if manifold === nothing
118-
obj_res = analyze(obj_expr)
119-
else
120-
obj_res = analyze(obj_expr, manifold)
121-
end
122-
123-
@info "Objective Euclidean curvature: $(obj_res.curvature)"
124-
125-
if obj_res.gcurvature !== nothing
126-
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
60+
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
12761
end
12862
else
12963
obj_res = nothing
130-
end
131-
132-
if cons_expr !== nothing && structural_analysis
133-
cons_expr = cons_expr .|> Symbolics.unwrap
134-
if manifold === nothing
135-
cons_res = analyze.(cons_expr)
136-
else
137-
cons_res = analyze.(cons_expr, Ref(manifold))
138-
end
139-
for i in 1:num_cons
140-
@info "Constraints Euclidean curvature: $(cons_res[i].curvature)"
141-
142-
if cons_res[i].gcurvature !== nothing
143-
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
144-
end
145-
end
146-
else
14764
cons_res = nothing
14865
end
14966

150-
return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,
67+
return OptimizationCache(f, reinit_cache_passedon, prob.lb, prob.ub, prob.lcons,
15168
prob.ucons, prob.sense,
15269
opt, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
15370
merge((; maxiters, maxtime, abstol, reltol),

src/symify.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
function symify_cache(f::OptimizationFunction, prob)
2+
obj_expr = f.expr
3+
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))
4+
5+
return obj_expr, cons_expr
6+
end

0 commit comments

Comments
 (0)