Skip to content

Commit 1eda645

Browse files
Merge pull request #3243 from AayushSabharwal/as/optsys-discover-variables
feat: add automatic variable discovery for `OptimizationSystem`
2 parents bdb4c03 + 2885815 commit 1eda645

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,33 @@ function OptimizationSystem(op, unknowns, ps;
144144
checks = checks)
145145
end
146146

147+
function OptimizationSystem(objective; constraints = [], kwargs...)
148+
allunknowns = OrderedSet()
149+
ps = OrderedSet()
150+
collect_vars!(allunknowns, ps, objective, nothing)
151+
for cons in constraints
152+
collect_vars!(allunknowns, ps, cons, nothing)
153+
end
154+
for ssys in get(kwargs, :systems, OptimizationSystem[])
155+
collect_scoped_vars!(allunknowns, ps, ssys, nothing)
156+
end
157+
new_ps = OrderedSet()
158+
for p in ps
159+
if iscall(p) && operation(p) === getindex
160+
par = arguments(p)[begin]
161+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
162+
all(par[i] in ps for i in eachindex(par))
163+
push!(new_ps, par)
164+
else
165+
push!(new_ps, p)
166+
end
167+
else
168+
push!(new_ps, p)
169+
end
170+
end
171+
return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
172+
end
173+
147174
function flatten(sys::OptimizationSystem)
148175
systems = get_systems(sys)
149176
isempty(systems) && return sys

src/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,15 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif
516516
end
517517
end
518518
end
519+
if has_constraints(sys)
520+
for eq in get_constraints(sys)
521+
eqtype_supports_collect_vars(eq) || continue
522+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
523+
end
524+
end
525+
if has_op(sys)
526+
collect_vars!(unknowns, parameters, get_op(sys), iv; depth, op)
527+
end
519528
newdepth = depth == -1 ? depth : depth + 1
520529
for ssys in get_systems(sys)
521530
collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op)
@@ -544,9 +553,10 @@ Can be dispatched by higher-level libraries to indicate support.
544553
"""
545554
eqtype_supports_collect_vars(eq) = false
546555
eqtype_supports_collect_vars(eq::Equation) = true
556+
eqtype_supports_collect_vars(eq::Inequality) = true
547557
eqtype_supports_collect_vars(eq::Pair) = true
548558

549-
function collect_vars!(unknowns, parameters, eq::Equation, iv;
559+
function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv;
550560
depth = 0, op = Differential)
551561
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
552562
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
@@ -559,6 +569,7 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
559569
return nothing
560570
end
561571

572+
562573
function collect_var!(unknowns, parameters, var, iv; depth = 0)
563574
isequal(var, iv) && return nothing
564575
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing

test/optimizationsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,14 @@ end
377377
prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)])
378378
@test abs(prob.f.cons(prob.u0, prob.p)[1]) 1.0
379379
end
380+
381+
@testset "Variable discovery" begin
382+
@variables x1 x2
383+
@parameters p1 p2
384+
@named sys1 = OptimizationSystem(x1^2; constraints = [p1 * x1 2.0])
385+
@named sys2 = OptimizationSystem(x2^2; constraints = [p2 * x2 2.0], systems = [sys1])
386+
@test isequal(only(unknowns(sys1)), x1)
387+
@test isequal(only(parameters(sys1)), p1)
388+
@test all(y -> any(x -> isequal(x, y), unknowns(sys2)), [x2, sys1.x1])
389+
@test all(y -> any(x -> isequal(x, y), parameters(sys2)), [p2, sys1.p1])
390+
end

0 commit comments

Comments
 (0)