@@ -11,7 +11,39 @@ function simplify_constants(O::Operation)
11
11
_O
12
12
end
13
13
14
+ const TREE_SHRINK_OPS = [:* ,:+ ]
15
+
14
16
function _simplify_constants (O)
17
+ # Tree shrinking
18
+ for cur_op in TREE_SHRINK_OPS
19
+ if Symbol (O. op) == cur_op
20
+ # Shrink tree
21
+ if any (x-> typeof (x)<: Operation && Symbol (x. op) == cur_op ,O. args)
22
+ idxs = find (x-> typeof (x)<: Operation && Symbol (x. op) == cur_op,O. args)
23
+ keep_idxs = 1 : length (O. args) .∉ (idxs,)
24
+ args = Vector{Expression}[O. args[i]. args for i in idxs]
25
+ push! (args,O. args[keep_idxs])
26
+ return Operation (O. op,vcat (args... ))
27
+ # Collapse constants
28
+ elseif length (find (x-> typeof (x)<: Variable && x. subtype == :Constant ,O. args)) > 1
29
+ idxs = find (x-> typeof (x)<: Variable && x. subtype == :Constant ,O. args)
30
+ other_idxs = 1 : length (O. args) .∉ (idxs,)
31
+ if cur_op == :*
32
+ new_var = Constant (prod (x-> x. value,O. args[idxs]))
33
+ elseif cur_op == :+
34
+ new_var = Constant (sum (x-> x. value,O. args[idxs]))
35
+ end
36
+ new_args = O. args[other_idxs]
37
+ push! (new_args,new_var)
38
+ if length (new_args) > 1
39
+ return Operation (O. op,new_args)
40
+ else
41
+ return new_args[1 ]
42
+ end
43
+ end
44
+ end
45
+ end
46
+
15
47
if Symbol (O. op) == :*
16
48
# If any variable is `Constant(0)`, zero the whole thing
17
49
# If any variable is `Constant(1)`, remove that `Constant(1)` unless
@@ -31,7 +63,9 @@ function _simplify_constants(O)
31
63
else
32
64
return O
33
65
end
34
- elseif Symbol (O. op) == :+ && any (x-> typeof (x)<: Variable && (isequal (x,Constant (0 )) || isequal (x,Constant (- 0 ))),O. args)
66
+ elseif (Symbol (O. op) == :+ || Symbol (O. op) == :- ) &&
67
+ any (x-> typeof (x)<: Variable && (isequal (x,Constant (0 )) ||
68
+ isequal (x,Constant (- 0 ))),O. args)
35
69
# If there are Constant(0)s in a big `+` expression, get rid of them
36
70
idxs = find (x-> typeof (x)<: Variable && (isequal (x,Constant (0 )) || isequal (x,Constant (- 0 ))),O. args)
37
71
_O = Operation (O. op,O. args[1 : length (O. args) .∉ (idxs,)])
0 commit comments