Skip to content

Commit 519038e

Browse files
Merge pull request #3209 from AayushSabharwal/as/array-bounds
feat: add support for bounds of array variables
2 parents 1b54706 + 6dbc375 commit 519038e

File tree

5 files changed

+119
-7
lines changed

5 files changed

+119
-7
lines changed

docs/src/basics/Variable_metadata.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,44 @@ hasbounds(u)
8686
getbounds(u)
8787
```
8888

89+
Bounds can also be specified for array variables. A scalar array bound is applied to each
90+
element of the array. A bound may also be specified as an array, in which case the size of
91+
the array must match the size of the symbolic variable.
92+
93+
```@example metadata
94+
@variables x[1:2, 1:2] [bounds = (-1, 1)]
95+
hasbounds(x)
96+
```
97+
98+
```@example metadata
99+
getbounds(x)
100+
```
101+
102+
```@example metadata
103+
getbounds(x[1, 1])
104+
```
105+
106+
```@example metadata
107+
getbounds(x[1:2, 1])
108+
```
109+
110+
```@example metadata
111+
@variables x[1:2] [bounds = (-Inf, [1.0, Inf])]
112+
hasbounds(x)
113+
```
114+
115+
```@example metadata
116+
getbounds(x)
117+
```
118+
119+
```@example metadata
120+
getbounds(x[2])
121+
```
122+
123+
```@example metadata
124+
hasbounds(x[2])
125+
```
126+
89127
## Guess
90128

91129
Specify an initial guess for custom initial conditions of an `ODESystem`.

src/systems/optimization/optimizationsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ function OptimizationSystem(op, unknowns, ps;
101101
gui_metadata = nothing)
102102
name === nothing &&
103103
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
104-
constraints = value.(scalarize(constraints))
105-
unknowns′ = value.(scalarize(unknowns))
104+
constraints = value.(reduce(vcat, scalarize(constraints); init = []))
105+
unknowns′ = value.(reduce(vcat, scalarize(unknowns); init = []))
106106
ps′ = value.(ps)
107107
op′ = value(scalarize(op))
108108

src/variables.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ end
254254
## Bounds ======================================================================
255255
struct VariableBounds end
256256
Symbolics.option_to_metadata_type(::Val{:bounds}) = VariableBounds
257-
getbounds(x::Num) = getbounds(Symbolics.unwrap(x))
258257

259258
"""
260259
getbounds(x)
@@ -266,10 +265,35 @@ Create parameters with bounds like this
266265
@parameters p [bounds=(-1, 1)]
267266
```
268267
"""
269-
function getbounds(x)
268+
function getbounds(x::Union{Num, Symbolics.Arr, SymbolicUtils.Symbolic})
269+
x = unwrap(x)
270270
p = Symbolics.getparent(x, nothing)
271-
p === nothing || (x = p)
272-
Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf))
271+
if p === nothing
272+
bounds = Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf))
273+
if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown()
274+
bounds = map(bounds) do b
275+
b isa AbstractArray && return b
276+
return fill(b, size(x))
277+
end
278+
end
279+
else
280+
# if we reached here, `x` is the result of calling `getindex`
281+
bounds = @something Symbolics.getmetadata(x, VariableBounds, nothing) getbounds(p)
282+
idxs = arguments(x)[2:end]
283+
bounds = map(bounds) do b
284+
if b isa AbstractArray
285+
if Symbolics.shape(p) != Symbolics.Unknown() && size(p) != size(b)
286+
throw(DimensionMismatch("Expected array variable $p with shape $(size(p)) to have bounds of identical size. Found $bounds of size $(size(bounds))."))
287+
end
288+
return b[idxs...]
289+
elseif symbolic_type(x) == ArraySymbolic()
290+
return fill(b, size(x))
291+
else
292+
return b
293+
end
294+
end
295+
end
296+
return bounds
273297
end
274298

275299
"""
@@ -280,7 +304,7 @@ See also [`getbounds`](@ref).
280304
"""
281305
function hasbounds(x)
282306
b = getbounds(x)
283-
isfinite(b[1]) || isfinite(b[2])
307+
any(isfinite.(b[1]) .|| isfinite.(b[2]))
284308
end
285309

286310
## Disturbance =================================================================

test/optimizationsystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,4 +358,13 @@ end
358358
@mtkbuild sys = OptimizationSystem(obj, [x, y, z], []; constraints = cons)
359359
@test is_variable(sys, z)
360360
@test !is_variable(sys, y)
361+
362+
@variables x[1:3] [bounds = ([-Inf, -1.0, -2.0], [Inf, 1.0, 2.0])]
363+
obj = x[1]^2 + x[2]^2 + x[3]^2
364+
cons = [x[2] ~ 2x[1] + 3, x[3] ~ x[1] + x[2]]
365+
@mtkbuild sys = OptimizationSystem(obj, [x], []; constraints = cons)
366+
@test length(unknowns(sys)) == 2
367+
@test !is_variable(sys, x[1])
368+
@test is_variable(sys, x[2])
369+
@test is_variable(sys, x[3])
361370
end

test/test_variable_metadata.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,47 @@ using ModelingToolkit
1010
@test !hasbounds(y)
1111
@test !haskey(ModelingToolkit.dump_variable_metadata(y), :bounds)
1212

13+
@variables y[1:3]
14+
@test !hasbounds(y)
15+
@test getbounds(y)[1] == [-Inf, -Inf, -Inf]
16+
@test getbounds(y)[2] == [Inf, Inf, Inf]
17+
for i in eachindex(y)
18+
@test !hasbounds(y[i])
19+
b = getbounds(y[i])
20+
@test b[1] == -Inf && b[2] == Inf
21+
end
22+
23+
@variables y[1:3] [bounds = (-1, 1)]
24+
@test hasbounds(y)
25+
@test getbounds(y)[1] == -ones(3)
26+
@test getbounds(y)[2] == ones(3)
27+
for i in eachindex(y)
28+
@test hasbounds(y[i])
29+
b = getbounds(y[i])
30+
@test b[1] == -1.0 && b[2] == 1.0
31+
end
32+
@test getbounds(y[1:2])[1] == -ones(2)
33+
@test getbounds(y[1:2])[2] == ones(2)
34+
35+
@variables y[1:2, 1:2] [bounds = (-1, [1.0 Inf; 2.0 3.0])]
36+
@test hasbounds(y)
37+
@test getbounds(y)[1] == [-1 -1; -1 -1]
38+
@test getbounds(y)[2] == [1.0 Inf; 2.0 3.0]
39+
for i in eachindex(y)
40+
@test hasbounds(y[i])
41+
b = getbounds(y[i])
42+
@test b[1] == -1 && b[2] == [1.0 Inf; 2.0 3.0][i]
43+
end
44+
45+
@variables y[1:2] [bounds = (-Inf, [1.0, Inf])]
46+
@test hasbounds(y)
47+
@test getbounds(y)[1] == [-Inf, -Inf]
48+
@test getbounds(y)[2] == [1.0, Inf]
49+
@test hasbounds(y[1])
50+
@test getbounds(y[1]) == (-Inf, 1.0)
51+
@test !hasbounds(y[2])
52+
@test getbounds(y[2]) == (-Inf, Inf)
53+
1354
# Guess
1455
@variables y [guess = 0]
1556
@test getguess(y) === 0

0 commit comments

Comments
 (0)