@@ -4,105 +4,122 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
44import .. NodeUtilsModule: tree_mapreduce, is_node_constant
55import .. OperatorEnumModule: AbstractOperatorEnum
66import .. ValueInterfaceModule: is_valid
7+ import .. EvaluateModule: get_nbin
78
89_una_op_kernel (f:: F , l:: T ) where {F,T} = f (l)
910_bin_op_kernel (f:: F , l:: T , r:: T ) where {F,T} = f (l, r)
1011
12+ # Operator traits
1113is_commutative (:: typeof (* )) = true
1214is_commutative (:: typeof (+ )) = true
1315is_commutative (_) = false
1416
15- is_subtraction (:: typeof (- )) = true
16- is_subtraction (_) = false
17+ # Zero-related traits
18+ has_right_identity_zero (:: typeof (+ )) = true
19+ has_right_identity_zero (:: typeof (- )) = true
20+ has_right_identity_zero (_) = false
21+
22+ has_left_identity_zero (:: typeof (+ )) = true
23+ has_left_identity_zero (_) = false
24+
25+ absorbed_by_zero (:: typeof (* )) = true
26+ absorbed_by_zero (_) = false
27+
28+ # One-related traits
29+ has_identity_one (:: typeof (* )) = true
30+ has_identity_one (_) = false
31+
32+ # Self-operation traits
33+ simplifies_given_equal_operands (:: typeof (/ )) = true
34+ simplifies_given_equal_operands (_) = false
1735
1836combine_operators (tree:: AbstractExpressionNode , :: AbstractOperatorEnum ) = tree
19- # This is only defined for `Node` as it is not possible for, e.g.,
20- # `GraphNode`.
37+
2138function combine_operators (tree:: Node{T} , operators:: AbstractOperatorEnum ) where {T}
22- # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
23- # ((const + var) + const) => (const + var)
24- # ((const * var) * const) => (const * var)
25- # ((const - var) - const) => (const - var)
26- # (want to add anything commutative!)
27- # TODO - need to combine plus/sub if they are both there.
28- if tree. degree == 0
29- return tree
30- elseif tree. degree == 1
31- tree. l = combine_operators (tree. l, operators)
32- elseif tree. degree == 2
33- tree. l = combine_operators (tree. l, operators)
34- tree. r = combine_operators (tree. r, operators)
39+ deg = tree. degree
40+ deg == 0 && return tree
41+ tree. l = combine_operators (tree. l, operators)
42+ deg == 1 && return tree
43+ tree. r = combine_operators (tree. r, operators)
44+ return dispatch_deg2_simplify (tree, operators)
45+ end
46+ @generated function dispatch_deg2_simplify (
47+ tree:: Node{T} , operators:: AbstractOperatorEnum
48+ ) where {T}
49+ nbin = get_nbin (operators)
50+ quote
51+ op_idx = tree. op
52+ return Base. Cartesian. @nif (
53+ $ nbin, i -> i == op_idx, i -> _combine_operators_on (operators. binops[i], tree)
54+ )
3555 end
56+ end
3657
37- top_level_constant =
38- tree. degree == 2 && (is_node_constant (tree. l) || is_node_constant (tree. r))
39- if tree. degree == 2 && is_commutative (operators. binops[tree. op]) && top_level_constant
40- # TODO : Does this break SymbolicRegression.jl due to the different names of operators?
41- op = tree. op
42- # Put the constant in r. Need to assume var in left for simplification assumption.
43- if is_node_constant (tree. l)
44- tmp = tree. r
45- tree. r = tree. l
46- tree. l = tmp
58+ function _combine_operators_on (f:: F , tree:: Node{T} ) where {F,T}
59+ # NOTE: This assumes tree.degree == 2 and tree.op corresponds to f
60+ # Handle basic simplifications first
61+ if is_node_constant (tree. r)
62+ rval = tree. r. val
63+ if rval == zero (T)
64+ # Operations where right zero is identity (x + 0 -> x, x - 0 -> x)
65+ if has_right_identity_zero (f)
66+ return tree. l
67+ end
68+ # Operations that are absorbed by zero (x * 0 -> 0)
69+ if absorbed_by_zero (f)
70+ return tree. r
71+ end
72+ elseif rval == one (T)
73+ # x * 1 -> x
74+ if has_identity_one (f)
75+ return tree. l
76+ end
4777 end
48- topconstant = tree. r. val
49- # Simplify down first
50- below = tree. l
51- if below. degree == 2 && below. op == op
52- if is_node_constant (below. l)
53- tree = below
54- tree. l. val = _bin_op_kernel (operators. binops[op], tree. l. val, topconstant)
55- elseif is_node_constant (below. r)
56- tree = below
57- tree. r. val = _bin_op_kernel (operators. binops[op], tree. r. val, topconstant)
78+ end
79+
80+ if is_node_constant (tree. l)
81+ lval = tree. l. val
82+ if lval == zero (T)
83+ # Operations where left zero is identity (0 + x -> x)
84+ if has_left_identity_zero (f)
85+ return tree. r
5886 end
87+ # Operations that are absorbed by zero (0 * x -> 0)
88+ if absorbed_by_zero (f)
89+ return tree. l
90+ end
91+ elseif lval == one (T) && has_identity_one (f)
92+ # 1 * x -> x
93+ return tree. r
5994 end
6095 end
6196
62- if tree. degree == 2 && is_subtraction (operators. binops[tree. op]) && top_level_constant
97+ # x/x -> 1, or other self-simplifying operations
98+ if simplifies_given_equal_operands (f) && tree. l == tree. r
99+ return constructorof (typeof (tree))(T; val= one (T))
100+ end
63101
64- # Currently just simplifies subtraction. (can't assume both plus and sub are operators)
65- # Not commutative, so use different op.
102+ # Handle commutative operations with constants
103+ if is_commutative (f) && (is_node_constant (tree. l) || is_node_constant (tree. r))
104+ # Put constant on right for consistent handling
66105 if is_node_constant (tree. l)
67- if tree. r. degree == 2 && tree. op == tree. r. op
68- if is_node_constant (tree. r. l)
69- # (const - (const - var)) => (var - const)
70- l = tree. l
71- r = tree. r
72- simplified_const = (r. l. val - l. val) # neg(sub(l.val, r.l.val))
73- tree. l = tree. r. r
74- tree. r = l
75- tree. r. val = simplified_const
76- elseif is_node_constant (tree. r. r)
77- # (const - (var - const)) => (const - var)
78- l = tree. l
79- r = tree. r
80- simplified_const = l. val + r. r. val # plus(l.val, r.r.val)
81- tree. r = tree. r. l
82- tree. l. val = simplified_const
83- end
84- end
85- else # tree.r is a constant
86- if tree. l. degree == 2 && tree. op == tree. l. op
87- if is_node_constant (tree. l. l)
88- # ((const - var) - const) => (const - var)
89- l = tree. l
90- r = tree. r
91- simplified_const = l. l. val - r. val# sub(l.l.val, r.val)
92- tree. r = tree. l. r
93- tree. l = r
94- tree. l. val = simplified_const
95- elseif is_node_constant (tree. l. r)
96- # ((var - const) - const) => (var - const)
97- l = tree. l
98- r = tree. r
99- simplified_const = r. val + l. r. val # plus(r.val, l.r.val)
100- tree. l = tree. l. l
101- tree. r. val = simplified_const
102- end
106+ tree. l, tree. r = tree. r, tree. l
107+ end
108+
109+ # Now we know tree.r is constant
110+ below = tree. l
111+ if below. degree == 2 && below. op == tree. op
112+ # Combine nested operations with constants: ((a * x) * b) -> ((a*b) * x)
113+ if is_node_constant (below. l)
114+ below. l. val = _bin_op_kernel (f, below. l. val, tree. r. val)
115+ return below
116+ elseif is_node_constant (below. r)
117+ below. r. val = _bin_op_kernel (f, below. r. val, tree. r. val)
118+ return below
103119 end
104120 end
105121 end
122+
106123 return tree
107124end
108125
@@ -121,7 +138,6 @@ function combine_children!(operators, p::N, c::N...) where {T,N<:AbstractExpress
121138 return p
122139end
123140
124- # Simplify tree
125141function simplify_tree! (tree:: AbstractExpressionNode , operators:: AbstractOperatorEnum )
126142 return tree_mapreduce (
127143 identity, (p, c... ) -> combine_children! (operators, p, c... ), tree, typeof (tree);
0 commit comments