Skip to content

Commit b1d228c

Browse files
committed
Addition of Sym and some error correction
With the new "splitting" of addition, there are sometimes equations where the RHS is just a symbol. SCMC previously couldn't handle this properly; now it can. Errors included not referring to equation types correctly, and accidentally making convex/concave portions the same in all_evaluators, resulting in the convex functions not actually being convex. This has been fixed.
1 parent 74afec2 commit b1d228c

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

src/transform/factor.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ function isfactor(ex::Term{Real,Nothing})
4141
return true
4242
end
4343

44+
function factor!(ex::Sym{Real, Base.ImmutableDict{DataType, Any}}; eqs = Equation[])
45+
index = findall(x -> isequal(x.rhs,ex), eqs)
46+
if isempty(index)
47+
newsym = gensym(:aux)
48+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
49+
newvar = genvar(newsym)
50+
new = Equation(Symbolics.value(newvar), ex)
51+
push!(eqs, new)
52+
else
53+
p = collect(1:length(eqs))
54+
deleteat!(p, index[1])
55+
push!(p, index[1])
56+
eqs[:] = eqs[p]
57+
end
58+
return eqs
59+
end
60+
4461
function factor!(ex::SymbolicUtils.Add; eqs = Equation[])
4562
binarize!(ex)
4663
if isfactor(ex)

src/transform/utilities.jl

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
arity(a::Equation) = arity(a.rhs)
33
arity(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = 1
44
arity(a::Term{Real, Nothing}) = 1
5+
arity(a::Sym{Real, Base.ImmutableDict{DataType,Any}}) = 1
56
arity(a::SymbolicUtils.Add) = length(a.dict) + (~iszero(a.coeff))
67
arity(a::SymbolicUtils.Mul) = length(a.dict) + (~isone(a.coeff))
78
arity(a::SymbolicUtils.Pow) = 2
@@ -14,12 +15,14 @@ op(::SymbolicUtils.Pow) = ^
1415
op(::SymbolicUtils.Div) = /
1516
op(::Term{Real, Base.ImmutableDict{DataType,Any}}) = nothing
1617
op(a::Term{Real, Nothing}) = a.f
18+
op(a::Sym{Real, Base.ImmutableDict{DataType,Any}}) = getindex
1719

1820
xstr(a::Equation) = sub_1(a.rhs)
1921
ystr(a::Equation) = sub_2(a.rhs)
2022
zstr(a::Equation) = a.lhs
2123

2224
sub_1(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = a
25+
sub_1(a::Sym{Real, Base.ImmutableDict{DataType,Any}}) = a
2326
function sub_1(a::SymbolicUtils.Add)
2427
sorted_dict = sort(collect(a.dict), by=x->string(x[1]))
2528
return sorted_dict[1].first
@@ -528,9 +531,9 @@ function all_evaluators(term::Num)
528531
lo_eqn += step_2[1].rhs
529532
hi_eqn += step_2[2].rhs
530533
cv_eqn += step_2[3].rhs
531-
cc_eqn += step_2[3].rhs
534+
cc_eqn += step_2[4].rhs
532535
end
533-
ordered_vars = pull_vars(step_2)
536+
ordered_vars = pull_vars(0 ~ cv_eqn)
534537
@eval lo_evaluator = $(build_function(lo_eqn, ordered_vars..., expression=Val{true}))
535538
@eval hi_evaluator = $(build_function(hi_eqn, ordered_vars..., expression=Val{true}))
536539
@eval cv_evaluator = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
@@ -548,25 +551,30 @@ function all_evaluators(term::Num)
548551
return lo_evaluator, hi_evaluator, cv_evaluator, cc_evaluator, ordered_vars
549552
end
550553
function all_evaluators(equation::Equation)
551-
if typeof(equation.rhs.val) <: SymbolicUtils.Add
552-
lo_eqn = equation.rhs.val.coeff
553-
hi_eqn = equation.rhs.val.coeff
554-
cv_eqn = equation.rhs.val.coeff
555-
cc_eqn = equation.rhs.val.coeff
556-
for (key,val) in equation.rhs.val.dict
554+
if typeof(equation.rhs) <: SymbolicUtils.Add
555+
lo_eqn = equation.rhs.coeff
556+
hi_eqn = equation.rhs.coeff
557+
cv_eqn = equation.rhs.coeff
558+
cc_eqn = equation.rhs.coeff
559+
for (key,val) in equation.rhs.dict
557560
new_equation = 0 ~ (val*key)
558561
step_1 = apply_transform(McCormickIntervalTransform(), [new_equation])
559562
step_2 = shrink_eqs(step_1)
560563
lo_eqn += step_2[1].rhs
561564
hi_eqn += step_2[2].rhs
562565
cv_eqn += step_2[3].rhs
563-
cc_eqn += step_2[3].rhs
566+
cc_eqn += step_2[4].rhs
564567
end
565-
ordered_vars = pull_vars(step_2)
568+
ordered_vars = pull_vars(0 ~ cv_eqn)
566569
@eval lo_evaluator = $(build_function(lo_eqn, ordered_vars..., expression=Val{true}))
567570
@eval hi_evaluator = $(build_function(hi_eqn, ordered_vars..., expression=Val{true}))
568571
@eval cv_evaluator = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
569572
@eval cc_evaluator = $(build_function(cc_eqn, ordered_vars..., expression=Val{true}))
573+
574+
@show length(string(lo_eqn))
575+
@show length(string(hi_eqn))
576+
@show length(string(cv_eqn))
577+
@show length(string(cc_eqn))
570578
else
571579
step_1 = apply_transform(McCormickIntervalTransform(), [equation])
572580
step_2 = shrink_eqs(step_1)
@@ -575,6 +583,11 @@ function all_evaluators(equation::Equation)
575583
@eval hi_evaluator = $(build_function(step_2[2].rhs, ordered_vars..., expression=Val{true}))
576584
@eval cv_evaluator = $(build_function(step_2[3].rhs, ordered_vars..., expression=Val{true}))
577585
@eval cc_evaluator = $(build_function(step_2[4].rhs, ordered_vars..., expression=Val{true}))
586+
587+
@show length(string(step_2[1].rhs))
588+
@show length(string(step_2[2].rhs))
589+
@show length(string(step_2[3].rhs))
590+
@show length(string(step_2[4].rhs))
578591
end
579592
return lo_evaluator, hi_evaluator, cv_evaluator, cc_evaluator, ordered_vars
580593
end

0 commit comments

Comments
 (0)