@@ -5,8 +5,7 @@ import ..NodeUtilsModule: tree_mapreduce, is_node_constant
55import .. OperatorEnumModule: AbstractOperatorEnum
66import .. ValueInterfaceModule: is_valid
77
8- _una_op_kernel (f:: F , l:: T ) where {F,T} = f (l)
9- _bin_op_kernel (f:: F , l:: T , r:: T ) where {F,T} = f (l, r)
8+ _op_kernel (f:: F , l:: T , ls:: T... ) where {F,T} = f (l, ls... )
109
1110is_commutative (:: typeof (* )) = true
1211is_commutative (:: typeof (+ )) = true
@@ -17,8 +16,8 @@ is_subtraction(_) = false
1716
1817combine_operators (tree:: AbstractExpressionNode , :: AbstractOperatorEnum ) = tree
1918# This is only defined for `Node` as it is not possible for, e.g.,
20- # `GraphNode`.
21- function combine_operators (tree:: Node{T} , operators:: AbstractOperatorEnum ) where {T}
19+ # `GraphNode`, and n-arity nodes .
20+ function combine_operators (tree:: Node{T,2 } , operators:: AbstractOperatorEnum ) where {T}
2221 # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
2322 # ((const + var) + const) => (const + var)
2423 # ((const * var) * const) => (const * var)
@@ -51,10 +50,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
5150 if below. degree == 2 && below. op == op
5251 if is_node_constant (below. l)
5352 tree = below
54- tree. l. val = _bin_op_kernel (operators. binops[op], tree. l. val, topconstant)
53+ tree. l. val = _op_kernel (operators. binops[op], tree. l. val, topconstant)
5554 elseif is_node_constant (below. r)
5655 tree = below
57- tree. r. val = _bin_op_kernel (operators. binops[op], tree. r. val, topconstant)
56+ tree. r. val = _op_kernel (operators. binops[op], tree. r. val, topconstant)
5857 end
5958 end
6059 end
@@ -106,15 +105,13 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
106105 return tree
107106end
108107
109- function combine_children! (operators, p:: N , c:: N... ) where {T,N<: AbstractExpressionNode{T} }
108+ function combine_children! (
109+ operators, p:: N , c:: Vararg{N,degree}
110+ ) where {T,N<: AbstractExpressionNode{T} ,degree}
110111 all (is_node_constant, c) || return p
111112 vals = map (n -> n. val, c)
112113 all (is_valid, vals) || return p
113- out = if length (c) == 1
114- _una_op_kernel (operators. unaops[p. op], vals... )
115- else
116- _bin_op_kernel (operators. binops[p. op], vals... )
117- end
114+ out = _op_kernel (operators[degree][p. op], vals... )
118115 is_valid (out) || return p
119116 new_node = constructorof (N)(T; val= convert (T, out))
120117 set_node! (p, new_node)
0 commit comments