Skip to content

Commit 74afec2

Browse files
committed
Several massive performance improvements
- Fixed multiplication of (variable*Real) to not be nearly as expansive (previously, this resulted in a long IfElse block where all the possibilities ended up being identical) - Changed "shrink_eqs" to progressively remove already-used equations to speed up later substitutions - The "convex_evaluator" and "all_evaluators" functions now check if the highest-level operation of a given expression is addition, and if so, they calculate relaxations for each operand separately and add them together at the end.
1 parent 46543d7 commit 74afec2

File tree

6 files changed

+157
-63
lines changed

6 files changed

+157
-63
lines changed

src/interval/interval.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ function var_names(::IntervalTransform, s::Term{Real, Nothing}) #Any terms like
4040
var_hi = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_hi"))
4141
end
4242
else
43+
println("Term: $s")
44+
for arg in s.arguments
45+
@show arg
46+
@show typeof(arg)
47+
end
4348
error("Type of argument invalid")
4449
end
4550

src/interval/rules.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ function transform_rule(::IntervalTransform, ::typeof(-), zL, zU, xL, xU, yL, yU
3535
ru = Equation(zU, xU - yL)
3636
return rl, ru
3737
end
38+
function transform_rule(::IntervalTransform, ::typeof(*), zL, zU, xL, xU, yL::Real, yU::Real)
39+
rl = Equation(zL, IfElse.ifelse(yL >= 0.0, yL*xL, yU*xU))
40+
ru = Equation(zU, IfElse.ifelse(yL >= 0.0, yU*xU, yL*xL))
41+
return rl, ru
42+
end
3843
function transform_rule(::IntervalTransform, ::typeof(*), zL, zU, xL, xU, yL, yU)
3944
rl = Equation(zL, IfElse.ifelse(yL >= 0.0, #x*pos
4045
IfElse.ifelse(xL >= 0.0, xL*yL, #pos*pos

src/relaxation/relaxation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818
function var_names(::McCormickTransform, s::Real)
1919
return s, s
2020
end
21-
function var_names(::McCormickTransform, s::Term{Real, Nothing}) #Any terms like "Differential"
21+
function var_names(::McCormickTransform, s::Term{Real, Nothing}) #Any terms like "Differential" or x[1]
2222
if typeof(s.arguments[1])<:Term #then it has args
2323
args = Symbol[]
2424
for i in s.arguments[1].arguments

src/relaxation/rules.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
function transform_rule(::McCormickIntervalTransform, rule::Any, yL, yU, ycv, ycc, xL, xU, xcv, xcc)
1010
rL, rU = transform_rule(IntervalTransform(), rule, yL, yU, xL, xU)
11-
rcv, rcc = transform_rule(McCormickTransform(), rule, ycv, ycc, yL, yU, xcv, xcc, xL, xU)
11+
rcv, rcc = transform_rule(McCormickTransform(), rule, yL, yU, ycv, ycc, xL, xU, xcv, xcc)
1212
return rL, rU, rcv, rcc
1313
end
1414
function transform_rule(::McCormickIntervalTransform, rule::Any, zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
@@ -47,6 +47,11 @@ end
4747

4848
# Rules for multiplication adapted from:
4949
# https://github.com/PSORLab/McCormick.jl/blob/master/src/forward_operators/multiplication.jl
50+
function transform_rule(::McCormickTransform, ::typeof(*), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL::Real, yU::Real, ycv::Real, ycc::Real)
51+
rcv = Equation(zcv, IfElse.ifelse(yL >= 0.0, ycv*xcv, ycc*xcc))
52+
rcc = Equation(zcc, IfElse.ifelse(yL >= 0.0, ycc*xcc, ycv*xcv))
53+
return rcv, rcc
54+
end
5055
function transform_rule(::McCormickTransform, ::typeof(*), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
5156
rcv = Equation(zcv, IfElse.ifelse(xL >= 0.0,
5257
IfElse.ifelse(yL >= 0.0, max(yU*xcv + xU*ycv - xU*yU, yL*xcv + xL*ycv - xL*yL),

src/transform/factor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
base_term(a::Any) = false
33
base_term(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = true
4-
base_term(a::Term{Real, Nothing}) = true
4+
base_term(a::Term{Real, Nothing}) = (a.f==getindex)
55
base_term(a::Sym) = true
66
base_term(a::Real) = true
77

src/transform/utilities.jl

Lines changed: 139 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,13 @@ function sub_2(a::SymbolicUtils.Pow)
5252
return a.exp
5353
end
5454

55-
sub_1(a::Term{Real, Nothing}) = a.arguments[1]
55+
function sub_1(a::Term{Real, Nothing})
56+
if a.f==getindex
57+
return a
58+
else
59+
return a.arguments[1]
60+
end
61+
end
5662
sub_2(a::Term{Real, Nothing}) = a.arguments[2]
5763

5864

@@ -366,18 +372,17 @@ shrink_eqs(eqs, 1)
366372
"""
367373
function shrink_eqs(eqs::Vector{Equation}, keep::Int64=4)
368374
new_eqs = eqs
369-
for i in 1:length(eqs)-keep
370-
new_eqs = substitute(new_eqs, Dict(new_eqs[i].lhs => new_eqs[i].rhs))
375+
for _ in 1:length(eqs)-keep
376+
new_eqs = substitute(new_eqs, Dict(new_eqs[1].lhs => new_eqs[1].rhs))[2:end]
371377
end
372-
new_eqs = new_eqs[end-(keep-1):end]
373378
return new_eqs
374379
end
375380

376381
"""
377-
convex_evaluator(::Equation)
378382
convex_evaluator(::Num)
383+
convex_evaluator(::Equation)
379384
380-
Given a symbolic equation or expression, return a function that evaluates
385+
Given a symbolic expression or equation, return a function that evaluates
381386
the convex relaxation of the expression or the equation's right-hand side
382387
and a list of correctly ordered arguments to this new function. To get
383388
evaluator functions for {lower bound, upper bound, convex relaxation,
@@ -427,75 +432,149 @@ out = evaluator.(x_cc, x_cv, x_hi, x_lo, y_cc, y_cv, y_hi, y_lo)
427432
as_array = Array(out)
428433
```
429434
"""
430-
function convex_evaluator(equation::Equation)
431-
# Apply the McCormickIntervalTransform to get expanded equations
432-
# defining the relaxations of the equation
433-
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
434-
435-
# Recursively substitute intermediate variables to get down to
436-
# 4 equations, representing the original equation's lower bound,
437-
# upper bound, convex relaxation, and concave relaxation
438-
step_2 = shrink_eqs(step_1)
439-
440-
# Extract all the variables from the smaller equation set and
441-
# organize them alphabetically
442-
ordered_vars = pull_vars(step_2)
443-
444-
# Create the new function. This works by calling Symbolics.build_function,
445-
# which creates a function as an Expr that evaluates build_function's first
446-
# argument, with the next argument(s) as the function's input(s). If we
447-
# set expression=Val{false}, build_function will return a compiled function
448-
# as a RuntimeGeneratedFunction, which we do NOT want as this is not
449-
# GPU-compatible. Instead, we keep expression=Val{true} (technically this is
450-
# the default) and we set new_func to be the evaluation of the returned Expr,
451-
# which is now a callable function. This line is delicate--don't change unless
452-
# you know what you're doing!
453-
@eval new_func = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
435+
function convex_evaluator(term::Num)
436+
# First, check to see if the term is "Add". If so, we can get some
437+
# huge time savings by separating out the expression using the knowledge
438+
# that the sum of convex relaxations is equal to the convex relaxation
439+
# of the sum (i.e., a_cv + b_cv = (a+b)_cv, and same for lo/hi/cc)
440+
if typeof(term.val) <: SymbolicUtils.Add
441+
# Start with any real-valued operands [if present]
442+
cv_eqn = term.val.coeff
443+
444+
# Loop through the dictionary of operands and treat each term like
445+
# its own equation
446+
for (key,val) in term.val.dict
447+
# "key" is the operand, "val" is its coefficient. The LHS of "equation" is irrelevant
448+
equation = 0 ~ (val*key)
449+
450+
# Apply the McCormick transform to expand out the equation with auxiliary
451+
# variables and get expressions for each variable's relaxations
452+
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
453+
454+
# Shrink the equations down to 4 total, for "lo", "hi", "cv", and "cc"
455+
step_2 = shrink_eqs(step_1)
456+
457+
# For "convex_evaluator" we only care about the convex part, which is #3 of 4.
458+
# See "all_evaluators" if you need more than just the convex relaxation
459+
cv_eqn += step_2[3].rhs
460+
end
454461

462+
# Scan through the equation and pick out and organize all variables needed as inputs
463+
ordered_vars = pull_vars(0 ~ cv_eqn)
464+
465+
# Create the evaluation function. This works by calling Symbolics.build_function,
466+
# which creates a function as an Expr that evaluates build_function's first
467+
# argument, with the next argument(s) as the function's input(s). If we
468+
# set expression=Val{false}, build_function will return a compiled function
469+
# as a RuntimeGeneratedFunction, which we do NOT want as this is not
470+
# GPU-compatible. Instead, we keep expression=Val{true} (technically this is
471+
# the default) and we set new_func to be the evaluation of the returned Expr,
472+
# which is now a callable function. This line is delicate--don't change unless
473+
# you know what you're doing!
474+
@eval new_func = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
475+
else
476+
# Same as previous block, but without the speedup from a_cv + b_cv = (a+b)_cv
477+
equation = 0 ~ term
478+
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
479+
step_2 = shrink_eqs(step_1)
480+
ordered_vars = pull_vars(step_2)
481+
@eval new_func = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
482+
end
455483
return new_func, ordered_vars
456484
end
457485

458-
function convex_evaluator(term::Num)
459-
# Same as the version with ::Equation as the input, but allows for
460-
# more intuitive input. I.e., not making an equation where the LHS
461-
# is meaningless. Since we need an equation to apply transforms,
462-
# though, we just make an equation anyway (with a meaningless LHS)
463-
equation = 0 ~ term
464-
465-
# And now everything is the same as the other version of this function
466-
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
467-
step_2 = shrink_eqs(step_1)
468-
ordered_vars = pull_vars(step_2)
469-
@eval new_func = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
486+
function convex_evaluator(equation::Equation)
487+
# Same as when the input is `Num`, but we have to deal with the input
488+
# already being an equation (whose LHS is irrelevant)
489+
if typeof(equation.rhs.val) <: SymbolicUtils.Add
490+
cv_eqn = equation.rhs.val.coeff
491+
for (key,val) in equation.rhs.val.dict
492+
new_equation = 0 ~ (val*key)
493+
step_1 = apply_transform(McCormickIntervalTransform(), [new_equation])
494+
step_2 = shrink_eqs(step_1)
495+
cv_eqn += step_2[3].rhs
496+
end
497+
ordered_vars = pull_vars(0~cv_eqn)
498+
@eval new_func = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
499+
500+
else
501+
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
502+
step_2 = shrink_eqs(step_1)
503+
ordered_vars = pull_vars(step_2)
504+
@eval new_func = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
505+
end
506+
470507
return new_func, ordered_vars
471508
end
472509

473510
"""
474-
all_evaluators(::Equation)
475511
all_evaluators(::Num)
512+
all_evaluators(::Equation)
476513
477514
See `convex_evaluator`. This function performs the same task, but returns
478515
four functions (representing lower bound, upper bound, convex relaxation,
479516
and concave relaxation evaluation functions) and the order vector.
480517
"""
481-
function all_evaluators(equation::Equation)
482-
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
483-
step_2 = shrink_eqs(step_1)
484-
ordered_vars = pull_vars(step_2)
485-
@eval lo_evaluator = $(build_function(step_2[1].rhs, ordered_vars..., expression=Val{true}))
486-
@eval hi_evaluator = $(build_function(step_2[2].rhs, ordered_vars..., expression=Val{true}))
487-
@eval cv_evaluator = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
488-
@eval cc_evaluator = $(build_function(step_2[4].rhs, ordered_vars..., expression=Val{true}))
518+
function all_evaluators(term::Num)
519+
if typeof(term.val) <: SymbolicUtils.Add
520+
lo_eqn = term.val.coeff
521+
hi_eqn = term.val.coeff
522+
cv_eqn = term.val.coeff
523+
cc_eqn = term.val.coeff
524+
for (key,val) in term.val.dict
525+
equation = 0 ~ (val*key)
526+
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
527+
step_2 = shrink_eqs(step_1)
528+
lo_eqn += step_2[1].rhs
529+
hi_eqn += step_2[2].rhs
530+
cv_eqn += step_2[3].rhs
531+
cc_eqn += step_2[3].rhs
532+
end
533+
ordered_vars = pull_vars(step_2)
534+
@eval lo_evaluator = $(build_function(lo_eqn, ordered_vars..., expression=Val{true}))
535+
@eval hi_evaluator = $(build_function(hi_eqn, ordered_vars..., expression=Val{true}))
536+
@eval cv_evaluator = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
537+
@eval cc_evaluator = $(build_function(cc_eqn, ordered_vars..., expression=Val{true}))
538+
else
539+
equation = 0 ~ term
540+
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
541+
step_2 = shrink_eqs(step_1)
542+
ordered_vars = pull_vars(step_2)
543+
@eval lo_evaluator = $(build_function(step_2[1].rhs, ordered_vars..., expression=Val{true}))
544+
@eval hi_evaluator = $(build_function(step_2[2].rhs, ordered_vars..., expression=Val{true}))
545+
@eval cv_evaluator = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
546+
@eval cc_evaluator = $(build_function(step_2[4].rhs, ordered_vars..., expression=Val{true}))
547+
end
489548
return lo_evaluator, hi_evaluator, cv_evaluator, cc_evaluator, ordered_vars
490549
end
491-
function all_evaluators(term::Num)
492-
equation = 0 ~ term
493-
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
494-
step_2 = shrink_eqs(step_1)
495-
ordered_vars = pull_vars(step_2)
496-
@eval lo_evaluator = $(build_function(step_2[1].rhs, ordered_vars..., expression=Val{true}))
497-
@eval hi_evaluator = $(build_function(step_2[2].rhs, ordered_vars..., expression=Val{true}))
498-
@eval cv_evaluator = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
499-
@eval cc_evaluator = $(build_function(step_2[4].rhs, ordered_vars..., expression=Val{true}))
550+
function all_evaluators(equation::Equation)
551+
if typeof(equation.rhs.val) <: SymbolicUtils.Add
552+
lo_eqn = equation.rhs.val.coeff
553+
hi_eqn = equation.rhs.val.coeff
554+
cv_eqn = equation.rhs.val.coeff
555+
cc_eqn = equation.rhs.val.coeff
556+
for (key,val) in equation.rhs.val.dict
557+
new_equation = 0 ~ (val*key)
558+
step_1 = apply_transform(McCormickIntervalTransform(), [new_equation])
559+
step_2 = shrink_eqs(step_1)
560+
lo_eqn += step_2[1].rhs
561+
hi_eqn += step_2[2].rhs
562+
cv_eqn += step_2[3].rhs
563+
cc_eqn += step_2[3].rhs
564+
end
565+
ordered_vars = pull_vars(step_2)
566+
@eval lo_evaluator = $(build_function(lo_eqn, ordered_vars..., expression=Val{true}))
567+
@eval hi_evaluator = $(build_function(hi_eqn, ordered_vars..., expression=Val{true}))
568+
@eval cv_evaluator = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
569+
@eval cc_evaluator = $(build_function(cc_eqn, ordered_vars..., expression=Val{true}))
570+
else
571+
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
572+
step_2 = shrink_eqs(step_1)
573+
ordered_vars = pull_vars(step_2)
574+
@eval lo_evaluator = $(build_function(step_2[1].rhs, ordered_vars..., expression=Val{true}))
575+
@eval hi_evaluator = $(build_function(step_2[2].rhs, ordered_vars..., expression=Val{true}))
576+
@eval cv_evaluator = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
577+
@eval cc_evaluator = $(build_function(step_2[4].rhs, ordered_vars..., expression=Val{true}))
578+
end
500579
return lo_evaluator, hi_evaluator, cv_evaluator, cc_evaluator, ordered_vars
501580
end

0 commit comments

Comments
 (0)