Skip to content

Commit 6564654

Browse files
committed
Accept Num argument for pull_vars; fix variable list issue
1 parent b1d228c commit 6564654

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

src/transform/factor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ function isfactor(ex::Term{Real,Nothing})
4141
return true
4242
end
4343

44+
factor!(ex::Num) = factor!(ex.val)
4445
function factor!(ex::Sym{Real, Base.ImmutableDict{DataType, Any}}; eqs = Equation[])
4546
index = findall(x -> isequal(x.rhs,ex), eqs)
4647
if isempty(index)

src/transform/utilities.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,13 @@ end
213213

214214

215215
"""
216-
pull_vars(eqn::Equation)
217-
pull_vars(eqns::Vector{Equation})
216+
pull_vars(::Num)
217+
pull_vars(::Vector{Num})
218+
pull_vars(::Equation)
219+
pull_vars(::Vector{Equation})
218220
219-
Pull out all variables/symbols from the RHS of an equation or set
220-
of equations and sorts them alphabetically.
221+
Pull out all variables/symbols from an expression or the RHS of an
222+
equation (or RHSs of a set of equations), and sort them alphabetically.
221223
222224
# Example
223225
@@ -230,6 +232,24 @@ julia> pull_vars(func)
230232
z
231233
```
232234
"""
235+
function pull_vars(term::Num)
236+
vars = Num[]
237+
strings = String[]
238+
vars, strings = _pull_vars(term, vars, strings)
239+
vars = vars[sortperm(strings)]
240+
return vars
241+
end
242+
243+
function pull_vars(terms::Vector{Num})
244+
vars = Num[]
245+
strings = String[]
246+
for term in terms
247+
vars, strings = _pull_vars(term, vars, strings)
248+
end
249+
vars = vars[sortperm(strings)]
250+
return vars
251+
end
252+
233253
function pull_vars(eqn::Equation)
234254
vars = Num[]
235255
strings = String[]
@@ -463,7 +483,7 @@ function convex_evaluator(term::Num)
463483
end
464484

465485
# Scan through the equation and pick out and organize all variables needed as inputs
466-
ordered_vars = pull_vars(0 ~ cv_eqn)
486+
ordered_vars = pull_vars(cv_eqn)
467487

468488
# Create the evaluation function. This works by calling Symbolics.build_function,
469489
# which creates a function as an Expr that evaluates build_function's first
@@ -497,7 +517,7 @@ function convex_evaluator(equation::Equation)
497517
step_2 = shrink_eqs(step_1)
498518
cv_eqn += step_2[3].rhs
499519
end
500-
ordered_vars = pull_vars(0~cv_eqn)
520+
ordered_vars = pull_vars(cv_eqn)
501521
@eval new_func = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
502522

503523
else
@@ -533,7 +553,7 @@ function all_evaluators(term::Num)
533553
cv_eqn += step_2[3].rhs
534554
cc_eqn += step_2[4].rhs
535555
end
536-
ordered_vars = pull_vars(0 ~ cv_eqn)
556+
ordered_vars = pull_vars([0~lo_eqn, 0~hi_eqn, 0~cv_eqn, 0~cc_eqn])
537557
@eval lo_evaluator = $(build_function(lo_eqn, ordered_vars..., expression=Val{true}))
538558
@eval hi_evaluator = $(build_function(hi_eqn, ordered_vars..., expression=Val{true}))
539559
@eval cv_evaluator = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))
@@ -565,7 +585,7 @@ function all_evaluators(equation::Equation)
565585
cv_eqn += step_2[3].rhs
566586
cc_eqn += step_2[4].rhs
567587
end
568-
ordered_vars = pull_vars(0 ~ cv_eqn)
588+
ordered_vars = pull_vars([0~lo_eqn, 0~hi_eqn, 0~cv_eqn, 0~cc_eqn])
569589
@eval lo_evaluator = $(build_function(lo_eqn, ordered_vars..., expression=Val{true}))
570590
@eval hi_evaluator = $(build_function(hi_eqn, ordered_vars..., expression=Val{true}))
571591
@eval cv_evaluator = $(build_function(cv_eqn, ordered_vars..., expression=Val{true}))

0 commit comments

Comments
 (0)