Skip to content

Commit 74f03df

Browse files
committed
Function barriers in simplification
1 parent d99937c commit 74f03df

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/SimplifyEquation.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import ..EquationModule: Node, copy_node
44
import ..OperatorEnumModule: AbstractOperatorEnum
55
import ..UtilsModule: isbad, isgood
66

7+
_una_op_kernel(f::F, l::T) where {F,T} = f(l)
8+
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
9+
710
# Simplify tree
811
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
912
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree before.
@@ -41,10 +44,14 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
4144
if below.degree == 2 && below.op == op
4245
if below.l.constant
4346
tree = below
44-
tree.l.val = operators.binops[op](tree.l.val::T, topconstant)
47+
tree.l.val = _bin_op_kernel(
48+
operators.binops[op], tree.l.val::T, topconstant
49+
)
4550
elseif below.r.constant
4651
tree = below
47-
tree.r.val = operators.binops[op](tree.r.val::T, topconstant)
52+
tree.r.val = _bin_op_kernel(
53+
operators.binops[op], tree.r.val::T, topconstant
54+
)
4855
end
4956
end
5057
end
@@ -96,13 +103,14 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
96103
end
97104

98105
# Simplify tree
106+
# TODO: This will get much more powerful with the tree-map functions.
99107
function simplify_tree(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
100108
if tree.degree == 1
101109
tree.l = simplify_tree(tree.l, operators)
102110
if tree.l.degree == 0 && tree.l.constant
103111
l = tree.l.val::T
104112
if isgood(l)
105-
out = operators.unaops[tree.op](l)
113+
out = _una_op_kernel(operators.unaops[tree.op], l)
106114
if isbad(out)
107115
return tree
108116
end
@@ -124,7 +132,7 @@ function simplify_tree(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
124132
end
125133

126134
# Actually compute:
127-
out = operators.binops[tree.op](l, r)
135+
out = _bin_op_kernel(operators.binops[tree.op], l, r)
128136
if isbad(out)
129137
return tree
130138
end

0 commit comments

Comments
 (0)