Skip to content

Commit ca30514

Browse files
committed
test: move simplification tests to testtem format
1 parent da31b90 commit ca30514

File tree

2 files changed

+92
-113
lines changed

2 files changed

+92
-113
lines changed

test/test_simplification.jl

Lines changed: 91 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,98 @@
1-
include("test_params.jl")
2-
using DynamicExpressions, Test
3-
import DynamicExpressions.StringsModule: strip_brackets
4-
import SymbolicUtils: simplify, Symbolic
5-
import Random: MersenneTwister
6-
import Base:
7-
8-
strip_brackets(a::String) = String(strip_brackets(collect(a)))
9-
10-
function Base.:(a::String, b::String)
11-
a = strip_brackets(a)
12-
b = strip_brackets(b)
13-
a = replace(a, r"\s+" => "")
14-
b = replace(b, r"\s+" => "")
15-
return a == b
1+
@testitem "SymbolicUtils conversion" begin
2+
using DynamicExpressions, Test
3+
import DynamicExpressions.StringsModule: strip_brackets
4+
import SymbolicUtils: simplify, Symbolic
5+
import Base:
6+
7+
strip_brackets(a::String) = String(strip_brackets(collect(a)))
8+
function Base.:(a::String, b::String)
9+
a = strip_brackets(a)
10+
b = strip_brackets(b)
11+
a = replace(a, r"\s+" => "")
12+
b = replace(b, r"\s+" => "")
13+
return a == b
14+
end
15+
16+
operators = OperatorEnum(; binary_operators=(+, -, /, *))
17+
tree = Node("x1") + Node("x1")
18+
19+
# Should simplify to 2*x1:
20+
eqn = convert(Symbolic, tree, operators)
21+
eqn2 = simplify(eqn)
22+
@test occursin("2", "$(repr(eqn2)[1])")
23+
24+
# Let's convert back the simplified version.
25+
tree = convert(Node, eqn2, operators)
26+
@test (tree.l.constant ? tree.l : tree.r).val == 2
27+
@test (!tree.l.constant ? tree.l : tree.r).feature == 1
28+
29+
# Test that SymbolicUtils does not convert multiplication to power:
30+
tree = Node("x1") * Node("x1")
31+
eqn = convert(Symbolic, tree, operators)
32+
@test repr(eqn) "x1*x1"
33+
tree_copy = convert(Node, eqn, operators)
34+
@test repr(tree_copy) "(x1*x1)"
1635
end
1736

18-
simplify_tree! = DynamicExpressions.SimplifyModule.simplify_tree!
19-
combine_operators = DynamicExpressions.SimplifyModule.combine_operators
20-
21-
binary_operators = (+, -, /, *)
22-
23-
index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1]
24-
25-
operators = OperatorEnum(; binary_operators=binary_operators)
26-
27-
tree = Node("x1") + Node("x1")
28-
29-
# Should simplify to 2*x1:
30-
eqn = convert(Symbolic, tree, operators)
31-
eqn2 = simplify(eqn)
32-
# Should correctly simplify to 2 x1:
33-
# (although it might use 2(x1^1))
34-
@test occursin("2", "$(repr(eqn2)[1])")
35-
36-
# Let's convert back the simplified version.
37-
# This should remove the ^ operator:
38-
tree = convert(Node, eqn2, operators)
39-
# Make sure one of the nodes is now 2.0:
40-
@test (tree.l.constant ? tree.l : tree.r).val == 2
41-
# Make sure the other node is x1:
42-
@test (!tree.l.constant ? tree.l : tree.r).feature == 1
43-
44-
# Finally, let's try converting a product, and ensure
45-
# that SymbolicUtils does not convert it to a power:
46-
tree = Node("x1") * Node("x1")
47-
eqn = convert(Symbolic, tree, operators)
48-
@test repr(eqn) "x1*x1"
49-
# Test converting back:
50-
tree_copy = convert(Node, eqn, operators)
51-
@test repr(tree_copy) "(x1*x1)"
52-
53-
# Let's test a much more complex function,
54-
# with custom operators, and unary operators:
55-
x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
56-
pow_abs2(x, y) = abs(x)^y
57-
58-
operators = OperatorEnum(;
59-
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
60-
)
61-
@extend_operators operators
62-
tree = (
63-
((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + (
64-
sin(
65-
custom_cos(
66-
sin(1.2926733 - 1.6606787) /
67-
sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426),
68-
) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)),
69-
) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262))
37+
@testitem "Complex expression simplification" begin
38+
using DynamicExpressions, Test
39+
using SymbolicUtils: simplify, Symbolic
40+
import Random: MersenneTwister
41+
include("test_params.jl")
42+
43+
x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
44+
pow_abs2(x, y) = abs(x)^y
45+
46+
operators = OperatorEnum(;
47+
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
48+
)
49+
@extend_operators operators
50+
tree = (
51+
((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + (
52+
sin(
53+
custom_cos(
54+
sin(1.2926733 - 1.6606787) /
55+
sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426),
56+
) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)),
57+
) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262))
58+
)
7059
)
71-
)
72-
# We use `index_functions` to avoid converting the custom operators into the primitives.
73-
eqn = convert(Symbolic, tree, operators; index_functions=true)
74-
75-
tree_copy = convert(Node, eqn, operators)
76-
tree_copy2 = convert(Node, simplify(eqn), operators)
77-
# Too difficult to check the representation, so we check by evaluation:
78-
N = 100
79-
X = rand(MersenneTwister(0), 3, N) .+ 0.1
80-
output1, flag1 = eval_tree_array(tree, X, operators)
81-
output2, flag2 = eval_tree_array(tree_copy, X, operators)
82-
output3, flag3 = eval_tree_array(tree_copy2, X, operators)
83-
84-
@test isapprox(output1, output2, atol=1e-4 * sqrt(N))
85-
# Simplified equation may give a different answer due to rounding errors,
86-
# so we weaken the requirement:
87-
@test isapprox(output1, output3, atol=1e-2 * sqrt(N))
60+
61+
eqn = convert(Symbolic, tree, operators; index_functions=true)
62+
tree_copy = convert(Node, eqn, operators)
63+
tree_copy2 = convert(Node, simplify(eqn), operators)
64+
65+
N = 100
66+
X = rand(MersenneTwister(0), 3, N) .+ 0.1
67+
output1, flag1 = eval_tree_array(tree, X, operators)
68+
output2, flag2 = eval_tree_array(tree_copy, X, operators)
69+
output3, flag3 = eval_tree_array(tree_copy2, X, operators)
70+
71+
@test isapprox(output1, output2, atol=1e-4 * sqrt(N))
72+
@test isapprox(output1, output3, atol=1e-2 * sqrt(N))
73+
end
8874

89-
###############################################################################
90-
## Hit other parts of `simplify_tree!` and `combine_operators` to increase
91-
## code coverage:
92-
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
93-
x1, x2, x3 = [Node(; feature=i) for i in 1:3]
94-
95-
# unary operator applied to constant => constant:
96-
tree = Node(1, Node(; val=0.0))
97-
@test repr(tree) "cos(0.0)"
98-
@test repr(simplify_tree!(tree, operators)) "1.0"
99-
100-
# except when the result is a NaN, then we don't change it:
101-
tree = Node(1, Node(; val=NaN))
102-
@test repr(tree) "cos(NaN)"
103-
@test repr(simplify_tree!(tree, operators)) "cos(NaN)"
104-
105-
# the same as above, but inside a binary tree.
106-
tree =
107-
Node(1, Node(1, Node(; val=0.1), Node(; val=0.2)) + Node(; val=0.2)) + Node(; val=2.0)
108-
@test repr(tree) "(cos((0.1 + 0.2) + 0.2) + 2.0)"
109-
@test repr(combine_operators(tree, operators)) "(cos(0.4 + 0.1) + 2.0)"
110-
111-
# left is constant:
112-
tree = Node(; val=0.5) + (Node(; val=0.2) + x1)
113-
@test repr(tree) "(0.5 + (0.2 + x1))"
114-
@test repr(combine_operators(tree, operators)) "(x1 + 0.7)"
75+
@testitem "Constant folding" begin
76+
using DynamicExpressions, Test
77+
78+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
79+
x1, x2, x3 = [Node(; feature=i) for i in 1:3]
80+
81+
# Unary operator applied to constant => constant:
82+
tree = Node(1, Node(; val=0.0))
83+
@test repr(tree) "cos(0.0)"
84+
@test repr(simplify_tree!(tree, operators)) "1.0"
85+
86+
# NaN handling
87+
tree = Node(1, Node(; val=NaN))
88+
@test repr(tree) "cos(NaN)"
89+
@test repr(simplify_tree!(tree, operators)) "cos(NaN)"
90+
91+
# Nested constant folding
92+
tree = Node(1, Node(1, Node(; val=0.1), Node(; val=0.2)) + Node(; val=0.2)) + Node(; val=2.0)
93+
@test repr(tree) "(cos((0.1 + 0.2) + 0.2) + 2.0)"
94+
@test repr(combine_operators(tree, operators)) "(cos(0.4 + 0.1) + 2.0)"
95+
end
11596

11697
# (const - (const - var)) => (var - const)
11798
tree = Node(2, Node(; val=0.5), Node(; val=0.2) - x1)

test/unittest.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ end
3838
include("test_undefined_derivatives.jl")
3939
end
4040

41-
@testitem "Test simplification" begin
42-
include("test_simplification.jl")
43-
end
41+
include("test_simplification.jl")
4442

4543
@testitem "Test printing" begin
4644
include("test_print.jl")

0 commit comments

Comments
 (0)