Skip to content

Commit 38e1fba

Browse files
committed
Added some support for Symbolics
Both binarize!() and factor!() now work with SymbolicUtils expressions of type {Add, Mul, Div, Pow}. More will be added later as necessary, but the basic functionality for these should be ready.
1 parent 95f7798 commit 38e1fba

File tree

3 files changed

+224
-17
lines changed

3 files changed

+224
-17
lines changed

src/transform/binarize.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,41 @@ function binarize!(ex::Expr)
1313
resize!(ex.args, 3)
1414
end
1515
return ex
16-
end
16+
end
17+
18+
function binarize!(ex::SymbolicUtils.Add)
19+
(arity(ex) < 3) && return ex
20+
# Op is already +
21+
skipfirst = iszero(ex.coeff)
22+
newdict = Dict{Any, Number}()
23+
for (key, val) in ex.dict
24+
if skipfirst
25+
skipfirst = false
26+
continue
27+
end
28+
newdict[key] = val
29+
delete!(ex.dict, key)
30+
end
31+
a = SymbolicUtils.Add(Real, 0, newdict)
32+
binarize!(a)
33+
ex.dict[a] = 1
34+
return nothing
35+
end
36+
function binarize!(ex::SymbolicUtils.Mul)
37+
(arity(ex) < 3) && return ex
38+
# Op is already *
39+
skipfirst = isone(ex.coeff)
40+
newdict = Dict{Any, Number}()
41+
for (key, val) in ex.dict
42+
if skipfirst
43+
skipfirst = false
44+
continue
45+
end
46+
newdict[key] = val
47+
delete!(ex.dict, key)
48+
end
49+
a = SymbolicUtils.Mul(Real, 1, newdict)
50+
binarize!(a)
51+
ex.dict[a] = 1
52+
return nothing
53+
end

src/transform/factor.jl

Lines changed: 184 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,43 @@ function is_factor(ex::Expr)
2121
(ex.head !== :call) && error("Only call expressions in rhs are currently supported.")
2222
num_or_var(ex.args[2:end])
2323
end
24+
function isfactor(ex::SymbolicUtils.Add)
25+
(~iszero(ex.coeff)) && (length(ex.dict)>1) && return false
26+
(iszero(ex.coeff)) && (length(ex.dict)>2) && return false
27+
for (key, val) in ex.dict
28+
~(isone(val)) && return false
29+
~(typeof(key)<:Term) && return false
30+
end
31+
return true
32+
end
33+
function isfactor(ex::SymbolicUtils.Mul)
34+
(~isone(ex.coeff)) && (length(ex.dict)>1) && return false
35+
(isone(ex.coeff)) && (length(ex.dict)>2) && return false
36+
for (key, val) in ex.dict
37+
~(isone(val)) && return false
38+
~(typeof(key)<:Term) && return false
39+
end
40+
return true
41+
end
42+
function isfactor(ex::SymbolicUtils.Div)
43+
~(typeof(ex.num)<:Term) && ~(typeof(ex.num)<:Real) && return false
44+
~(typeof(ex.den)<:Term) && ~(typeof(ex.num)<:Real) && return false
45+
return true
46+
end
47+
function isfactor(ex::SymbolicUtils.Pow)
48+
~(typeof(ex.base)<:Term) && ~(typeof(ex.base)<:Real) && return false
49+
~(typeof(ex.exp)<:Term) && ~(typeof(ex.exp)<:Real) && return false
50+
return true
51+
end
52+
2453

2554
# factor!(ex::Number; assignments::Vector{Assignment}) = assignments
2655
# factor!(ex::Symbol; assignments::Vector{Assignment}) = assignments
2756
factor!(ex::NTuple; assignments::Vector{Assignment}) = factor!(Expr(:($([i for i in ex]...))), assignments=assignments)
2857
factor!(ex::Tuple; assignments::Vector{Assignment}) = factor!(Expr(:($([i for i in ex]...))), assignments=assignments)
2958

3059
function factor!(ex::Number; assignments = Assignment[])
60+
println("Inside this one")
3161
index = findall(x -> x.rhs==ex, assignments)
3262
if isempty(index)
3363
newsym = gensym(:aux)
@@ -91,24 +121,162 @@ function factor!(ex::Expr; assignments = Assignment[])
91121
factor!(Expr(:($([i for i in new_expr]...))), assignments=assignments)
92122
return assignments
93123
end
124+
function factor!(ex::SymbolicUtils.Add; assignments = Assignment[])
125+
binarize!(ex)
126+
if isfactor(ex)
127+
index = findall(x -> isequal(x.rhs,ex), assignments)
128+
if isempty(index)
129+
newsym = gensym(:aux)
130+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
131+
newvar = genvar(newsym)
132+
new = Assignment(Symbolics.value(newvar), ex)
133+
push!(assignments, new)
134+
else
135+
p = collect(1:length(assignments))
136+
deleteat!(p, index[1])
137+
push!(p, index[1])
138+
assignments[:] = assignments[p]
139+
end
140+
return assignments
141+
end
94142

143+
new_terms = Dict{Any, Number}()
144+
for (key, val) in ex.dict
145+
if (typeof(key)<:Term) && isone(val)
146+
new_terms[key] = val
147+
elseif (typeof(key)<:Term)
148+
index = findall(x -> isequal(x.rhs,val*key), assignments)
149+
if isempty(index)
150+
newsym = gensym(:aux)
151+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
152+
newvar = genvar(newsym)
153+
new = Assignment(Symbolics.value(newvar), val*key)
154+
push!(assignments, new)
155+
new_terms[Symbolics.value(newvar)] = 1
156+
else
157+
new_terms[assignments[index[1]].lhs] = 1
158+
end
159+
else
160+
factor!(key, assignments=assignments)
161+
new_terms[assignments[end].lhs] = 1
162+
end
163+
end
164+
new_add = SymbolicUtils.Add(Real, ex.coeff, new_terms)
165+
factor!(new_add, assignments=assignments)
166+
return assignments
167+
end
168+
function factor!(ex::SymbolicUtils.Mul; assignments = Assignment[])
169+
binarize!(ex)
170+
if isfactor(ex)
171+
index = findall(x -> isequal(x.rhs,ex), assignments)
172+
if isempty(index)
173+
newsym = gensym(:aux)
174+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
175+
newvar = genvar(newsym)
176+
new = Assignment(Symbolics.value(newvar), ex)
177+
push!(assignments, new)
178+
else
179+
p = collect(1:length(assignments))
180+
deleteat!(p, index[1])
181+
push!(p, index[1])
182+
assignments[:] = assignments[p]
183+
end
184+
return assignments
185+
end
95186

96-
# Works with this example:
97-
# x + y + z/x
98-
# ex = Expr(:call, :+, :x, :y, Expr(:call, :/, :z, :x))
99-
# a = factor!(ex)
187+
new_terms = Dict{Any, Number}()
188+
for (key, val) in ex.dict
189+
if (typeof(key)<:Term) && isone(val)
190+
new_terms[key] = val
191+
elseif (typeof(key)<:Term)
192+
index = findall(x -> isequal(x.rhs,key^val), assignments)
193+
if isempty(index)
194+
newsym = gensym(:aux)
195+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
196+
newvar = genvar(newsym)
197+
new = Assignment(Symbolics.value(newvar), key^val)
198+
push!(assignments, new)
199+
new_terms[Symbolics.value(newvar)] = 1
200+
else
201+
new_terms[assignments[index[1]].lhs] = 1
202+
end
203+
else
204+
factor!(key, assignments=assignments)
205+
new_terms[assignments[end].lhs] = 1
206+
end
207+
end
208+
new_mul = SymbolicUtils.Mul(Real, ex.coeff, new_terms)
209+
factor!(new_mul, assignments=assignments)
210+
return assignments
211+
end
212+
function factor!(ex::SymbolicUtils.Pow; assignments = Assignment[])
213+
if isfactor(ex)
214+
index = findall(x -> isequal(x.rhs,ex), assignments)
215+
if isempty(index)
216+
newsym = gensym(:aux)
217+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
218+
newvar = genvar(newsym)
219+
new = Assignment(Symbolics.value(newvar), ex)
220+
push!(assignments, new)
221+
else
222+
p = collect(1:length(assignments))
223+
deleteat!(p, index[1])
224+
push!(p, index[1])
225+
assignments[:] = assignments[p]
226+
end
227+
return assignments
228+
end
229+
230+
if typeof(ex.base)<:Term
231+
new_base = ex.base
232+
else
233+
factor!(ex.base, assignments=assignments)
234+
new_base = assignments[end].lhs
235+
end
236+
if typeof(ex.exp)<:Term
237+
new_exp = ex.exp
238+
else
239+
factor!(ex.exp, assignments=assignments)
240+
new_exp = assignments[end].lhs
241+
end
242+
new_pow = SymbolicUtils.Pow(new_base, new_exp)
243+
factor!(new_pow, assignments=assignments)
244+
return assignments
245+
end
246+
function factor!(ex::SymbolicUtils.Div; assignments = Assignment[])
247+
if isfactor(ex)
248+
index = findall(x -> isequal(x.rhs,ex), assignments)
249+
if isempty(index)
250+
newsym = gensym(:aux)
251+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
252+
newvar = genvar(newsym)
253+
new = Assignment(Symbolics.value(newvar), ex)
254+
push!(assignments, new)
255+
else
256+
p = collect(1:length(assignments))
257+
deleteat!(p, index[1])
258+
push!(p, index[1])
259+
assignments[:] = assignments[p]
260+
end
261+
return assignments
262+
end
263+
264+
if typeof(ex.num)<:Term
265+
new_num = ex.num
266+
else
267+
factor!(ex.num, assignments=assignments)
268+
new_num = assignments[end].lhs
269+
end
270+
if typeof(ex.den)<:Term
271+
new_den = ex.den
272+
else
273+
factor!(ex.den, assignments=assignments)
274+
new_den = assignments[end].lhs
275+
end
276+
new_div = SymbolicUtils.Div(new_num, new_den)
277+
factor!(new_div, assignments=assignments)
278+
return assignments
279+
end
100280

101-
# Now this works too
102-
# x^3 + 5*y*x^2 + 3*y^2*x + 15*y^3 + 20
103-
# ex2 = Expr(:call, :+, Expr(:call, :^, :x, 3), Expr(:call, :*, 5, :y, Expr(:call, :^, :x, 2)), Expr(:call, :*, 3, Expr(:call, :^, :y, 2), :x), Expr(:call, :*, 15, Expr(:call, :^, :y, 3)), 20)
104-
# b = factor!(ex2)
105281

106-
# This example also seems to be working properly
107-
# (x/(y/z)) + (y/z) + (z/x)
108-
# ex3 = Expr(:call, :+, Expr(:call, :/, :x, Expr(:call, :/, :y, :z)), Expr(:call, :/, :y, :z), Expr(:call, :/, :z, :x))
109-
# c = factor!(ex3)
110282

111-
# This is correct as well
112-
# (a/(b/(c/(d/(e/(f/g))))))
113-
# ex4 = Expr(:call, :/, :a, Expr(:call, :/, :b, Expr(:call, :/, :c, Expr(:call, :/, :d, Expr(:call, :/, :e, Expr(:call, :/, :f, :g))))))
114-
# d = factor!(ex4)

src/transform/utilities.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ first(a::Expr) = a.args
33
arity(a::Expr) = length(a.args) - 1
44
arity(a::Assignment) = arity(a.rhs)
55
arity(a::Number) = 1
6+
arity(a::SymbolicUtils.Add) = length(a.dict) + (~iszero(a.coeff))
7+
arity(a::SymbolicUtils.Mul) = length(a.dict) + (~isone(a.coeff))
68

79
op(a::Expr) = a.args[1]
810
op(a::Assignment) = op(a.rhs)

0 commit comments

Comments
 (0)