Skip to content

Commit 3626d8f

Browse files
committed
feat: additional simplification types and generality
1 parent ca30514 commit 3626d8f

File tree

3 files changed

+164
-102
lines changed

3 files changed

+164
-102
lines changed

src/Simplify.jl

Lines changed: 93 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,105 +4,122 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
44
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
55
import ..OperatorEnumModule: AbstractOperatorEnum
66
import ..ValueInterfaceModule: is_valid
7+
import ..EvaluateModule: get_nbin
78

89
_una_op_kernel(f::F, l::T) where {F,T} = f(l)
910
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
1011

12+
# Operator traits
1113
is_commutative(::typeof(*)) = true
1214
is_commutative(::typeof(+)) = true
1315
is_commutative(_) = false
1416

15-
is_subtraction(::typeof(-)) = true
16-
is_subtraction(_) = false
17+
# Zero-related traits
18+
has_right_identity_zero(::typeof(+)) = true
19+
has_right_identity_zero(::typeof(-)) = true
20+
has_right_identity_zero(_) = false
21+
22+
has_left_identity_zero(::typeof(+)) = true
23+
has_left_identity_zero(_) = false
24+
25+
absorbed_by_zero(::typeof(*)) = true
26+
absorbed_by_zero(_) = false
27+
28+
# One-related traits
29+
has_identity_one(::typeof(*)) = true
30+
has_identity_one(_) = false
31+
32+
# Self-operation traits
33+
simplifies_given_equal_operands(::typeof(/)) = true
34+
simplifies_given_equal_operands(_) = false
1735

1836
combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
19-
# This is only defined for `Node` as it is not possible for, e.g.,
20-
# `GraphNode`.
37+
2138
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
22-
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
23-
# ((const + var) + const) => (const + var)
24-
# ((const * var) * const) => (const * var)
25-
# ((const - var) - const) => (const - var)
26-
# (want to add anything commutative!)
27-
# TODO - need to combine plus/sub if they are both there.
28-
if tree.degree == 0
29-
return tree
30-
elseif tree.degree == 1
31-
tree.l = combine_operators(tree.l, operators)
32-
elseif tree.degree == 2
33-
tree.l = combine_operators(tree.l, operators)
34-
tree.r = combine_operators(tree.r, operators)
39+
deg = tree.degree
40+
deg == 0 && return tree
41+
tree.l = combine_operators(tree.l, operators)
42+
deg == 1 && return tree
43+
tree.r = combine_operators(tree.r, operators)
44+
return dispatch_deg2_simplify(tree, operators)
45+
end
46+
@generated function dispatch_deg2_simplify(
47+
tree::Node{T}, operators::AbstractOperatorEnum
48+
) where {T}
49+
nbin = get_nbin(operators)
50+
quote
51+
op_idx = tree.op
52+
return Base.Cartesian.@nif(
53+
$nbin, i -> i == op_idx, i -> _combine_operators_on(operators.binops[i], tree)
54+
)
3555
end
56+
end
3657

37-
top_level_constant =
38-
tree.degree == 2 && (is_node_constant(tree.l) || is_node_constant(tree.r))
39-
if tree.degree == 2 && is_commutative(operators.binops[tree.op]) && top_level_constant
40-
# TODO: Does this break SymbolicRegression.jl due to the different names of operators?
41-
op = tree.op
42-
# Put the constant in r. Need to assume var in left for simplification assumption.
43-
if is_node_constant(tree.l)
44-
tmp = tree.r
45-
tree.r = tree.l
46-
tree.l = tmp
58+
function _combine_operators_on(f::F, tree::Node{T}) where {F,T}
59+
# NOTE: This assumes tree.degree == 2 and tree.op corresponds to f
60+
# Handle basic simplifications first
61+
if is_node_constant(tree.r)
62+
rval = tree.r.val
63+
if rval == zero(T)
64+
# Operations where right zero is identity (x + 0 -> x, x - 0 -> x)
65+
if has_right_identity_zero(f)
66+
return tree.l
67+
end
68+
# Operations that are absorbed by zero (x * 0 -> 0)
69+
if absorbed_by_zero(f)
70+
return tree.r
71+
end
72+
elseif rval == one(T)
73+
# x * 1 -> x
74+
if has_identity_one(f)
75+
return tree.l
76+
end
4777
end
48-
topconstant = tree.r.val
49-
# Simplify down first
50-
below = tree.l
51-
if below.degree == 2 && below.op == op
52-
if is_node_constant(below.l)
53-
tree = below
54-
tree.l.val = _bin_op_kernel(operators.binops[op], tree.l.val, topconstant)
55-
elseif is_node_constant(below.r)
56-
tree = below
57-
tree.r.val = _bin_op_kernel(operators.binops[op], tree.r.val, topconstant)
78+
end
79+
80+
if is_node_constant(tree.l)
81+
lval = tree.l.val
82+
if lval == zero(T)
83+
# Operations where left zero is identity (0 + x -> x)
84+
if has_left_identity_zero(f)
85+
return tree.r
5886
end
87+
# Operations that are absorbed by zero (0 * x -> 0)
88+
if absorbed_by_zero(f)
89+
return tree.l
90+
end
91+
elseif lval == one(T) && has_identity_one(f)
92+
# 1 * x -> x
93+
return tree.r
5994
end
6095
end
6196

62-
if tree.degree == 2 && is_subtraction(operators.binops[tree.op]) && top_level_constant
97+
# x/x -> 1, or other self-simplifying operations
98+
if simplifies_given_equal_operands(f) && tree.l == tree.r
99+
return constructorof(typeof(tree))(T; val=one(T))
100+
end
63101

64-
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
65-
# Not commutative, so use different op.
102+
# Handle commutative operations with constants
103+
if is_commutative(f) && (is_node_constant(tree.l) || is_node_constant(tree.r))
104+
# Put constant on right for consistent handling
66105
if is_node_constant(tree.l)
67-
if tree.r.degree == 2 && tree.op == tree.r.op
68-
if is_node_constant(tree.r.l)
69-
#(const - (const - var)) => (var - const)
70-
l = tree.l
71-
r = tree.r
72-
simplified_const = (r.l.val - l.val) #neg(sub(l.val, r.l.val))
73-
tree.l = tree.r.r
74-
tree.r = l
75-
tree.r.val = simplified_const
76-
elseif is_node_constant(tree.r.r)
77-
#(const - (var - const)) => (const - var)
78-
l = tree.l
79-
r = tree.r
80-
simplified_const = l.val + r.r.val #plus(l.val, r.r.val)
81-
tree.r = tree.r.l
82-
tree.l.val = simplified_const
83-
end
84-
end
85-
else #tree.r is a constant
86-
if tree.l.degree == 2 && tree.op == tree.l.op
87-
if is_node_constant(tree.l.l)
88-
#((const - var) - const) => (const - var)
89-
l = tree.l
90-
r = tree.r
91-
simplified_const = l.l.val - r.val#sub(l.l.val, r.val)
92-
tree.r = tree.l.r
93-
tree.l = r
94-
tree.l.val = simplified_const
95-
elseif is_node_constant(tree.l.r)
96-
#((var - const) - const) => (var - const)
97-
l = tree.l
98-
r = tree.r
99-
simplified_const = r.val + l.r.val #plus(r.val, l.r.val)
100-
tree.l = tree.l.l
101-
tree.r.val = simplified_const
102-
end
106+
tree.l, tree.r = tree.r, tree.l
107+
end
108+
109+
# Now we know tree.r is constant
110+
below = tree.l
111+
if below.degree == 2 && below.op == tree.op
112+
# Combine nested operations with constants: ((a * x) * b) -> ((a*b) * x)
113+
if is_node_constant(below.l)
114+
below.l.val = _bin_op_kernel(f, below.l.val, tree.r.val)
115+
return below
116+
elseif is_node_constant(below.r)
117+
below.r.val = _bin_op_kernel(f, below.r.val, tree.r.val)
118+
return below
103119
end
104120
end
105121
end
122+
106123
return tree
107124
end
108125

@@ -121,7 +138,6 @@ function combine_children!(operators, p::N, c::N...) where {T,N<:AbstractExpress
121138
return p
122139
end
123140

124-
# Simplify tree
125141
function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
126142
return tree_mapreduce(
127143
identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);

test/test_parametric_expression.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ end
161161

162162
(init_out, true_out) = map(
163163
p -> [
164-
(X[1, i] * p[2, classes[i]] + X[2, i] + p[1, classes[i]]) for i in 1:size(X, 2)
164+
(X[1, i] * p[2, classes[i]] + X[2, i] + p[1, classes[i]]) for
165+
i in eachindex(axes(X, 2))
165166
],
166167
(init_parameters, true_parameters),
167168
)

test/test_simplification.jl

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ end
5151
((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + (
5252
sin(
5353
custom_cos(
54-
sin(1.2926733 - 1.6606787) /
55-
sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426),
54+
sin(1.2926733 - 1.6606787) / sin(
55+
((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426
56+
),
5657
) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)),
5758
) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262))
5859
)
5960
)
60-
61+
6162
eqn = convert(Symbolic, tree, operators; index_functions=true)
6263
tree_copy = convert(Node, eqn, operators)
6364
tree_copy2 = convert(Node, simplify(eqn), operators)
@@ -89,28 +90,72 @@ end
8990
@test repr(simplify_tree!(tree, operators)) "cos(NaN)"
9091

9192
# Nested constant folding
92-
tree = Node(1, Node(1, Node(; val=0.1), Node(; val=0.2)) + Node(; val=0.2)) + Node(; val=2.0)
93+
tree =
94+
Node(1, Node(1, Node(; val=0.1), Node(; val=0.2)) + Node(; val=0.2)) +
95+
Node(; val=2.0)
9396
@test repr(tree) "(cos((0.1 + 0.2) + 0.2) + 2.0)"
9497
@test repr(combine_operators(tree, operators)) "(cos(0.4 + 0.1) + 2.0)"
9598
end
9699

97-
# (const - (const - var)) => (var - const)
98-
tree = Node(2, Node(; val=0.5), Node(; val=0.2) - x1)
99-
@test repr(tree) "(0.5 - (0.2 - x1))"
100-
@test repr(combine_operators(tree, operators)) "(x1 - -0.3)"
101-
102-
# ((const - var) - const) => (const - var)
103-
tree = Node(2, Node(; val=0.5) - x1, Node(; val=0.2))
104-
@test repr(tree) "((0.5 - x1) - 0.2)"
105-
@test repr(combine_operators(tree, operators)) "(0.3 - x1)"
106-
107-
# (const - (var - const)) => (const - var)
108-
tree = Node(2, Node(; val=0.5), x1 - Node(; val=0.2))
109-
@test repr(tree) "(0.5 - (x1 - 0.2))"
110-
@test repr(combine_operators(tree, operators)) "(0.7 - x1)"
111-
112-
# ((var - const) - const) => (var - const)
113-
tree = ((x1 - 0.2) - 0.6)
114-
@test repr(tree) "((x1 - 0.2) - 0.6)"
115-
@test repr(combine_operators(tree, operators)) "(x1 - 0.8)"
116-
###############################################################################
100+
@testitem "Basic operator simplifications" begin
101+
using DynamicExpressions, Test
102+
import DynamicExpressions.SimplifyModule: combine_operators
103+
104+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
105+
x = Node(; feature=1)
106+
zero_node = Node(; val=0.0)
107+
one_node = Node(; val=1.0)
108+
two_node = Node(; val=2.0)
109+
three_node = Node(; val=3.0)
110+
111+
# multiplication by 0
112+
tree = zero_node * x
113+
@test combine_operators(tree, operators) == zero_node
114+
tree = x * zero_node
115+
@test combine_operators(tree, operators) == zero_node
116+
117+
# multiplication by 1
118+
tree = one_node * x
119+
@test combine_operators(tree, operators) == x
120+
tree = x * one_node
121+
@test combine_operators(tree, operators) == x
122+
123+
# addition by 0
124+
tree = zero_node + x
125+
@test combine_operators(tree, operators) == x
126+
tree = x + zero_node
127+
@test combine_operators(tree, operators) == x
128+
129+
# division by self -> 1
130+
tree = x / x
131+
@test combine_operators(tree, operators).val == 1.0
132+
133+
# nested multiplication by constants
134+
tree1 = (two_node * x) * three_node
135+
tree2 = Node(; val=6.0) * x
136+
@test combine_operators(tree1, operators) == combine_operators(tree2, operators)
137+
end
138+
139+
@testitem "Constant combination" begin
140+
using DynamicExpressions, Test
141+
import DynamicExpressions.SimplifyModule: combine_operators
142+
143+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
144+
x1 = Node(; feature=1)
145+
146+
# Test commutative constant combination
147+
tree = Node(; val=0.5) + (Node(; val=0.2) + x1)
148+
@test combine_operators(tree, operators) == x1 + Node(; val=0.7)
149+
150+
# Test nested multiplication by constants
151+
tree = (Node(; val=2.0) * x1) * Node(; val=3.0)
152+
@test combine_operators(tree, operators) == x1 * Node(; val=6.0)
153+
154+
# Test nested addition by constants
155+
tree = (Node(; val=2.0) + x1) + Node(; val=3.0)
156+
@test combine_operators(tree, operators) == x1 + Node(; val=5.0)
157+
158+
# Test mixed operations don't combine incorrectly
159+
tree = (Node(; val=2.0) * x1) + Node(; val=3.0)
160+
@test combine_operators(tree, operators) == tree
161+
end

0 commit comments

Comments
 (0)