1
- function simplify_constants (O:: Operation ,shorten_tree = true )
1
+ function simplify_constants (O:: Operation , shorten_tree = true )
2
2
O_last = nothing
3
3
_O = O
4
4
while _O != O_last
@@ -11,79 +11,62 @@ function simplify_constants(O::Operation,shorten_tree = true)
11
11
_O
12
12
end
13
13
14
- const TREE_SHRINK_OPS = [* , + ]
14
+ const AC_OPERATORS = [* , + ]
15
15
16
- function _simplify_constants (O,shorten_tree = true )
16
+ function _simplify_constants (O, shorten_tree = true )
17
17
# Tree shrinking
18
- if shorten_tree
19
- for cur_op in TREE_SHRINK_OPS
20
- if O. op == cur_op
21
- # Shrink tree
22
- if any (x -> is_operation (x) && x. op == cur_op, O. args)
23
- idxs = findall (x -> is_operation (x) && x. op == cur_op, O. args)
24
- keep_idxs = 1 : length (O. args) .∉ (idxs,)
25
- args = Vector{Expression}[O. args[i]. args for i in idxs]
26
- push! (args,O. args[keep_idxs])
27
- return Operation (O. op, vcat (args... ))
28
- end
29
- # Collapse constants
30
- idxs = findall (is_constant, O. args)
31
- if length (idxs) > 1
32
- other_idxs = 1 : length (O. args) .∉ (idxs,)
33
- if cur_op == (* )
34
- new_var = Constant (prod (get, O. args[idxs]))
35
- elseif cur_op == (+ )
36
- new_var = Constant (sum (get, O. args[idxs]))
37
- end
38
- new_args = O. args[other_idxs]
39
- push! (new_args,new_var)
40
- if length (new_args) > 1
41
- return Operation (O. op,new_args)
42
- else
43
- return new_args[1 ]
44
- end
45
- end
46
- end
18
+ if shorten_tree && O. op ∈ AC_OPERATORS
19
+ # Flatten tree
20
+ idxs = findall (x -> is_operation (x) && x. op == O. op, O. args)
21
+ if ! isempty (idxs)
22
+ keep_idxs = eachindex (O. args) .∉ Ref (idxs)
23
+ args = Vector{Expression}[O. args[i]. args for i in idxs]
24
+ push! (args, O. args[keep_idxs])
25
+ return Operation (O. op, vcat (args... ))
26
+ end
27
+
28
+ # Collapse constants
29
+ idxs = findall (is_constant, O. args)
30
+ if length (idxs) > 1
31
+ other_idxs = eachindex (O. args) .∉ (idxs,)
32
+ new_var = Constant (mapreduce (get, O. op, O. args[idxs]))
33
+ new_args = O. args[other_idxs]
34
+ push! (new_args,new_var)
35
+
36
+ return length (new_args) > 1 ? Operation (O. op, new_args) : first (new_args)
47
37
end
48
38
end
49
39
50
- if O. op == (* )
40
+ if O. op === (* )
51
41
# If any variable is `Constant(0)`, zero the whole thing
42
+ any (iszero, O. args) && return Constant (0 )
43
+
52
44
# If any variable is `Constant(1)`, remove that `Constant(1)` unless
53
45
# they are all `Constant(1)`, in which case simplify to a single variable
54
- if any (iszero, O. args)
55
- return Constant (0 )
56
- elseif any (isone, O. args)
57
- idxs = findall (isone, O. args)
58
- _O = Operation (O. op,O. args[1 : length (O. args) .∉ (idxs,)])
59
- if isempty (_O. args)
60
- return Constant (1 )
61
- elseif length (_O. args) == 1
62
- return _O. args[1 ]
63
- else
64
- return _O
65
- end
66
- else
67
- return O
68
- end
69
- elseif O. op == (+ ) && any (iszero, O. args)
70
- # If there are Constant(0)s in a big `+` expression, get rid of them
71
- idxs = findall (iszero, O. args)
72
- _O = Operation (O. op,O. args[1 : length (O. args) .∉ (idxs,)])
73
- if isempty (_O. args)
74
- return Constant (0 )
75
- elseif length (_O. args) == 1
76
- return _O. args[1 ]
77
- else
78
- return O
46
+ if any (isone, O. args)
47
+ args = filter (! isone, O. args)
48
+
49
+ isempty (args) && return Constant (1 )
50
+ length (args) == 1 && return first (args)
51
+ return Operation (O. op, args)
79
52
end
80
- elseif O. op == identity
81
- return O. args[1 ]
82
- elseif O. op == (- ) && length (O. args) == 1
83
- return Operation (* ,Expression[- 1 ,O. args[1 ]])
84
- else
53
+
85
54
return O
86
55
end
56
+
57
+ if O. op === (+ ) && any (iszero, O. args)
58
+ # If there are Constant(0)s in a big `+` expression, get rid of them
59
+ args = filter (! iszero, O. args)
60
+
61
+ isempty (args) && return Constant (0 )
62
+ length (args) == 1 && return first (args)
63
+ return Operation (O. op, args)
64
+ end
65
+
66
+ (O. op, length (O. args)) === (identity, 1 ) && return O. args[1 ]
67
+
68
+ (O. op, length (O. args)) === (- , 1 ) && return Operation (* , Expression[- 1 , O. args[1 ]])
69
+
87
70
return O
88
71
end
89
72
simplify_constants (x:: Variable ,y= false ) = x
0 commit comments