Skip to content

Commit eedb658

Browse files
committed
Additions from algebraic GPU work
Major: - Added "convex_evaluator" and "all_evaluators" functions - Added methods for apply_transform with an Equation or Vector{Equation} as the input instead of an ODESystem - Added pull_vars to extract variable names from Symbolics expressions - Added shrink_eqs to shrink vectors of equations via substitution Minor: - Added a fix for variables like x[1] that are not dynamic functions of other variables such as "t". - Added support for ^2 operation - Corrected mid_expr to now work properly with Symbolics
1 parent 39ec973 commit eedb658

File tree

10 files changed

+1472
-31
lines changed

10 files changed

+1472
-31
lines changed

Manifest.toml

Lines changed: 1030 additions & 0 deletions
Large diffs are not rendered by default.

Project.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
name = "SourceCodeMcCormick"
2+
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
3+
authors = ["Robert Gottlieb <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
8+
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
9+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
10+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
11+
12+
[compat]
13+
IfElse = "0.1.0 - 1.1.1"
14+
ModelingToolkit = "~8"
15+
SymbolicUtils = "~0.19"

src/SourceCodeMcCormick.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ module SourceCodeMcCormick
33

44
using ModelingToolkit
55
using SymbolicUtils.Code
6+
using IfElse
7+
using DocStringExtensions
8+
69
# TODO: Need to import Assignment and other stuff probably
710
# Check out the functionality in ModelingToolkit.jl and Symbolics.jl
811

@@ -14,7 +17,9 @@ function transform_rule end
1417

1518
export McCormickIntervalTransform
1619

17-
export apply_transform, extract_terms, genvar, genparam, get_name
20+
export apply_transform, extract_terms, genvar, genparam, get_name,
21+
factor!, binarize!, pull_vars, shrink_eqs, convex_evaluator,
22+
all_evaluators
1823

1924
include(joinpath(@__DIR__, "interval", "interval.jl"))
2025
include(joinpath(@__DIR__, "relaxation", "relaxation.jl"))

src/interval/interval.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@ end
2222
function var_names(::IntervalTransform, s::Real)
2323
return s, s
2424
end
25-
function var_names(::IntervalTransform, s::Term{Real, Nothing}) #Any terms like "Differential"
26-
if length(s.arguments)>1
27-
error("Multiple arguments not supported.")
28-
end
29-
if typeof(s.arguments[1])<:Term #then it has args
25+
function var_names(::IntervalTransform, s::Term{Real, Nothing}) #Any terms like "Differential", or "x[1]" (NOT x[1](t))
26+
if typeof(s.arguments[1])<:Term #then it has typical args like "x", "y", ...
3027
args = Symbol[]
3128
for i in s.arguments[1].arguments
3229
push!(args, get_name(i))
3330
end
3431
var = get_name(s.arguments[1])
3532
var_lo = genvar(Symbol(string(var)*"_lo"), args)
3633
var_hi = genvar(Symbol(string(var)*"_hi"), args)
37-
elseif typeof(s.arguments[1])<:Sym #Then it has no args
38-
var_lo = genparam(Symbol(string(s.arguments[1].name)*"_lo"))
39-
var_hi = genparam(Symbol(string(s.arguments[1].name)*"_hi"))
34+
elseif typeof(s.arguments[1])<:Sym #Then it has no typical args, i.e., x[1] has args Any[x, 1]
35+
if length(s.arguments)==1
36+
var_lo = genparam(Symbol(string(s.arguments[1].name)*"_lo"))
37+
var_hi = genparam(Symbol(string(s.arguments[1].name)*"_hi"))
38+
else
39+
var_lo = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_lo"))
40+
var_hi = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_hi"))
41+
end
4042
else
4143
error("Type of argument invalid")
4244
end
@@ -74,18 +76,13 @@ function translate_initial_conditions(::IntervalTransform, prob::ODESystem, new_
7476
end
7577

7678

77-
78-
# function var_names(::IntervalTransform, s::Number)
79-
# sL = s
80-
# sU = s
81-
# sL, sU
82-
# end
83-
8479
# Helper functions for navigating SymbolicUtils structures
8580
get_name(x::Sym{SymbolicUtils.FnType{Tuple{Any}, Real}, Nothing}) = x.name
8681

8782
"""
88-
Takes x[1,1] returns :x_1_1
83+
get_name
84+
85+
Take a Symbolic-type object such as `x[1,1]` and return a symbol like `:x_1_1`.
8986
"""
9087
function get_name(s::Term{SymbolicUtils.FnType{Tuple, Real}, Nothing})
9188
d = s.arguments

src/interval/rules.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11

2-
using IfElse
3-
42
# Transformation rules should have the general form listed below:
53
# Check IntervalArithmetic for valid interval bound rules (we won't be able
64
# to do correctly rounded stuff on GPUs but otherwise the operators should work out)
@@ -13,6 +11,11 @@ function transform_rule(::IntervalTransform, nothing, yL, yU, xL, xU)
1311
ru = Equation(yU, xU)
1412
return rl, ru
1513
end
14+
function transform_rule(::IntervalTransform, ::typeof(getindex), yL, yU, xL, xU)
15+
rl = Equation(yL, xL)
16+
ru = Equation(yU, xU)
17+
return rl, ru
18+
end
1619
function transform_rule(::IntervalTransform, ::typeof(exp), yL, yU, xL, xU)
1720
rl = Equation(yL, exp(xL))
1821
ru = Equation(yU, exp(xU))
@@ -63,5 +66,12 @@ function transform_rule(::IntervalTransform, ::typeof(max), zL, zU, xL, xU, yL,
6366
ru = Equation(zU, max(xU, yU))
6467
return rl, ru
6568
end
69+
function transform_rule(::IntervalTransform, ::typeof(^), zL, zU, xL, xU, yL, yU)
70+
~((typeof(yL) <: Int) || (typeof(yL) <: AbstractFloat)) && error("Symbolic exponents not currently supported.")
71+
~(yL == 2) && error("Exponents besides 2 not currently supported")
72+
rl = Equation(zL, max(min(xU, 0.0), xL)^2)
73+
ru = Equation(zU, IfElse.ifelse(xU < 0.0, xL, IfElse.ifelse(xL > 0, xU, IfElse.ifelse(abs(xL) >= abs(xU), xL, xU)))^2)
74+
return rl, ru
75+
end
6676

6777
# TODO: /, ^

src/relaxation/relaxation.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ function var_names(::McCormickTransform, s::Real)
1919
return s, s
2020
end
2121
function var_names(::McCormickTransform, s::Term{Real, Nothing}) #Any terms like "Differential"
22-
if length(s.arguments)>1
23-
error("Multiple arguments not supported.")
24-
end
2522
if typeof(s.arguments[1])<:Term #then it has args
2623
args = Symbol[]
2724
for i in s.arguments[1].arguments
@@ -31,8 +28,13 @@ function var_names(::McCormickTransform, s::Term{Real, Nothing}) #Any terms like
3128
var_cv = genvar(Symbol(string(var)*"_cv"), args)
3229
var_cc = genvar(Symbol(string(var)*"_cc"), args)
3330
elseif typeof(s.arguments[1])<:Sym #Then it has no args
34-
var_cv = genparam(Symbol(string(s.arguments[1].name)*"_cv"))
35-
var_cc = genparam(Symbol(string(s.arguments[1].name)*"_cc"))
31+
if length(s.arguments)==1
32+
var_cv = genparam(Symbol(string(s.arguments[1].name)*"_cv"))
33+
var_cc = genparam(Symbol(string(s.arguments[1].name)*"_cc"))
34+
else
35+
var_cv = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_cv"))
36+
var_cc = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_cc"))
37+
end
3638
else
3739
error("Type of argument invalid")
3840
end
@@ -105,7 +107,7 @@ end
105107
line_expr(x, xL, xU, zL, zU) = IfElse.ifelse(zU > zL, (zL*(xU - x) + zU*(x - xL))/(xU - xL), zU)
106108

107109
# A symbolic way of computing the mid of three numbers (returns IfElse block)
108-
mid_expr(a, b, c) = IfElse.ifelse((a < b) && (b < c), y, IfElse.ifelse((c < b) && (b < a), b,
109-
IfElse.ifelse((b < a) && (a < c), x, IfElse.ifelse((c < a) && (a < b), a, c))))
110+
mid_expr(a, b, c) = IfElse.ifelse(a < b, IfElse.ifelse(b < c, b, IfElse.ifelse(c < a, a, c)),
111+
IfElse.ifelse(c < b, b, IfElse.ifelse(a < c, a, c)))
110112

111113
include(joinpath(@__DIR__, "rules.jl"))

src/relaxation/rules.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
using IfElse
31

42
# Note: McCormick transform is only the convex and concave portions of the transformation,
53
# so that the interval transforms (which are the same as for a regular Interval Transform)
@@ -24,6 +22,11 @@ end
2422
#=
2523
Unitary Rules
2624
=#
25+
function transform_rule(::McCormickTransform, ::typeof(getindex), yL, yU, ycv, ycc, xL, xU, xcv, xcc)
26+
rcv = Equation(ycv, xcv)
27+
rcc = Equation(ycc, xcc)
28+
return rcv, rcc
29+
end
2730
function transform_rule(::McCormickTransform, ::typeof(exp), yL, yU, ycv, ycc, xL, xU, xcv, xcc)
2831
mcv = mid_expr(xcv, xcc, xL)
2932
mcc = mid_expr(xcv, xcc, xU)
@@ -154,6 +157,17 @@ function transform_rule(::McCormickTransform, ::typeof(max), zL, zU, zcv, zcc, x
154157
return rcv, rcc
155158
end
156159

160+
function transform_rule(::McCormickTransform, ::typeof(^), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
161+
~((typeof(yL) <: Int) || (typeof(yL) <: AbstractFloat)) && error("Symbolic exponents not currently supported.")
162+
~(yL == 2) && error("Exponents besides 2 not currently supported")
163+
mcv = mid_expr(xcv, xcc, max(min(xU, 0.0), xL))
164+
mcc = mid_expr(xcv, xcc, IfElse.ifelse(xU < 0.0, xL, IfElse.ifelse(xL > 0, xU, IfElse.ifelse(abs(xL) >= abs(xU), xL, xU))))
165+
rcv = Equation(zcv, mcv^2)
166+
rcc = Equation(zcc, (xL+xU)*mcc - xU*xL)
167+
return rcv, rcc
168+
end
169+
170+
157171
#=
158172
TODO: Add other operators. It's probably helpful to break the McCormick overload and McCormick + Interval Outputs
159173
into separate transform_rules since the coupling for the ODEs are one directional and potentially useful.

src/transform/factor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
base_term(a::Any) = false
33
base_term(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = true
4+
base_term(a::Term{Real, Nothing}) = true
45
base_term(a::Sym) = true
56
base_term(a::Real) = true
67

src/transform/transform.jl

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ function apply_transform(transform::T, prob::ODESystem) where T<:AbstractTransfo
161161
else
162162
index = findall(x -> isequal(x.rhs, eqn.rhs), equations)
163163
push!(equations, Equation(eqn.lhs, equations[index[1]].lhs))
164-
165164
end
166165
end
167166

@@ -215,3 +214,71 @@ function apply_transform(transform::T, prob::ODESystem) where T<:AbstractTransfo
215214

216215
return new_sys
217216
end
217+
218+
# Separate case for applying transform to only a set of equations
219+
function apply_transform(transform::T, eqn_vector::Vector{Equation}) where T<:AbstractTransform
220+
221+
# Factorize all model equations to generate a new set of equations
222+
equations = Equation[]
223+
for eqn in eqn_vector
224+
current = length(equations)
225+
factor!(eqn.rhs, eqs=equations)
226+
if length(equations) > current
227+
push!(equations, Equation(eqn.lhs, equations[end].rhs))
228+
deleteat!(equations, length(equations)-1)
229+
else
230+
index = findall(x -> isequal(x.rhs, eqn.rhs), equations)
231+
push!(equations, Equation(eqn.lhs, equations[index[1]].lhs))
232+
end
233+
end
234+
235+
# Apply transform rules to the factored equations to make the final equation set
236+
new_equations = Equation[]
237+
for a in equations
238+
zn = var_names(transform, zstr(a))
239+
xn = var_names(transform, xstr(a))
240+
if isone(arity(a))
241+
targs = (transform, op(a), zn..., xn...)
242+
else
243+
targs = (transform, op(a), zn..., xn..., var_names(transform, ystr(a))...)
244+
end
245+
new = transform_rule(targs...)
246+
for i in new
247+
push!(new_equations, i)
248+
end
249+
end
250+
251+
return new_equations
252+
end
253+
function apply_transform(transform::T, eqn::Equation) where T<:AbstractTransform
254+
255+
# Factorize the equations to generate a new set of equations
256+
equations = Equation[]
257+
current = 0
258+
factor!(eqn.rhs, eqs=equations)
259+
if length(equations) > current
260+
push!(equations, Equation(eqn.lhs, equations[end].rhs))
261+
deleteat!(equations, length(equations)-1)
262+
else
263+
index = findall(x -> isequal(x.rhs, eqn.rhs), equations)
264+
push!(equations, Equation(eqn.lhs, equations[index[1]].lhs))
265+
end
266+
267+
# Apply transform rules to the factored equations to make the final equation set
268+
new_equations = Equation[]
269+
for a in equations
270+
zn = var_names(transform, zstr(a))
271+
xn = var_names(transform, xstr(a))
272+
if isone(arity(a))
273+
targs = (transform, op(a), zn..., xn...)
274+
else
275+
targs = (transform, op(a), zn..., xn..., var_names(transform, ystr(a))...)
276+
end
277+
new = transform_rule(targs...)
278+
for i in new
279+
push!(new_equations, i)
280+
end
281+
end
282+
283+
return new_equations
284+
end

0 commit comments

Comments
 (0)