Skip to content

Commit 7b51c06

Browse files
committed
feat: n-arity compat with simplification
1 parent 5f977ab commit 7b51c06

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/Evaluate.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ function deg0_eval(
332332
end
333333
end
334334

335+
# This basically forms an if statement over the operators for the degree.
335336
@generated function inner_dispatch_degn_eval(
336337
tree::AbstractExpressionNode{T},
337338
cX::AbstractMatrix{T},
@@ -359,6 +360,8 @@ end
359360
)
360361
end
361362
end
363+
364+
# This forms an if statement over the degree of a given node.
362365
@generated function dispatch_degn_eval(
363366
tree::AbstractExpressionNode{T},
364367
cX::AbstractMatrix{T},

src/Simplify.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ import ..NodeUtilsModule: tree_mapreduce, is_node_constant
55
import ..OperatorEnumModule: AbstractOperatorEnum
66
import ..ValueInterfaceModule: is_valid
77

8-
_una_op_kernel(f::F, l::T) where {F,T} = f(l)
9-
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
8+
_op_kernel(f::F, l::T, ls::T...) where {F,T} = f(l, ls...)
109

1110
is_commutative(::typeof(*)) = true
1211
is_commutative(::typeof(+)) = true
@@ -17,8 +16,8 @@ is_subtraction(_) = false
1716

1817
combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
1918
# This is only defined for `Node` as it is not possible for, e.g.,
20-
# `GraphNode`.
21-
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
19+
# `GraphNode`, and n-arity nodes.
20+
function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) where {T}
2221
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
2322
# ((const + var) + const) => (const + var)
2423
# ((const * var) * const) => (const * var)
@@ -51,10 +50,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
5150
if below.degree == 2 && below.op == op
5251
if is_node_constant(below.l)
5352
tree = below
54-
tree.l.val = _bin_op_kernel(operators.binops[op], tree.l.val, topconstant)
53+
tree.l.val = _op_kernel(operators.binops[op], tree.l.val, topconstant)
5554
elseif is_node_constant(below.r)
5655
tree = below
57-
tree.r.val = _bin_op_kernel(operators.binops[op], tree.r.val, topconstant)
56+
tree.r.val = _op_kernel(operators.binops[op], tree.r.val, topconstant)
5857
end
5958
end
6059
end
@@ -106,15 +105,13 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
106105
return tree
107106
end
108107

109-
function combine_children!(operators, p::N, c::N...) where {T,N<:AbstractExpressionNode{T}}
108+
function combine_children!(
109+
operators, p::N, c::Vararg{N,degree}
110+
) where {T,N<:AbstractExpressionNode{T},degree}
110111
all(is_node_constant, c) || return p
111112
vals = map(n -> n.val, c)
112113
all(is_valid, vals) || return p
113-
out = if length(c) == 1
114-
_una_op_kernel(operators.unaops[p.op], vals...)
115-
else
116-
_bin_op_kernel(operators.binops[p.op], vals...)
117-
end
114+
out = _op_kernel(operators[degree][p.op], vals...)
118115
is_valid(out) || return p
119116
new_node = constructorof(N)(T; val=convert(T, out))
120117
set_node!(p, new_node)

0 commit comments

Comments
 (0)