|
1 | | -include("test_params.jl") |
2 | | -using DynamicExpressions, Test |
3 | | -import DynamicExpressions.StringsModule: strip_brackets |
4 | | -import SymbolicUtils: simplify, Symbolic |
5 | | -import Random: MersenneTwister |
6 | | -import Base: ≈ |
7 | | - |
8 | | -strip_brackets(a::String) = String(strip_brackets(collect(a))) |
9 | | - |
10 | | -function Base.:≈(a::String, b::String) |
11 | | - a = strip_brackets(a) |
12 | | - b = strip_brackets(b) |
13 | | - a = replace(a, r"\s+" => "") |
14 | | - b = replace(b, r"\s+" => "") |
15 | | - return a == b |
| 1 | +@testitem "SymbolicUtils conversion" begin |
| 2 | + using DynamicExpressions, Test |
| 3 | + import DynamicExpressions.StringsModule: strip_brackets |
| 4 | + import SymbolicUtils: simplify, Symbolic |
| 5 | + import Base: ≈ |
| 6 | + |
| 7 | + strip_brackets(a::String) = String(strip_brackets(collect(a))) |
| 8 | + function Base.:≈(a::String, b::String) |
| 9 | + a = strip_brackets(a) |
| 10 | + b = strip_brackets(b) |
| 11 | + a = replace(a, r"\s+" => "") |
| 12 | + b = replace(b, r"\s+" => "") |
| 13 | + return a == b |
| 14 | + end |
| 15 | + |
| 16 | + operators = OperatorEnum(; binary_operators=(+, -, /, *)) |
| 17 | + tree = Node("x1") + Node("x1") |
| 18 | + |
| 19 | + # Should simplify to 2*x1: |
| 20 | + eqn = convert(Symbolic, tree, operators) |
| 21 | + eqn2 = simplify(eqn) |
| 22 | + @test occursin("2", "$(repr(eqn2)[1])") |
| 23 | + |
| 24 | + # Let's convert back the simplified version. |
| 25 | + tree = convert(Node, eqn2, operators) |
| 26 | + @test (tree.l.constant ? tree.l : tree.r).val == 2 |
| 27 | + @test (!tree.l.constant ? tree.l : tree.r).feature == 1 |
| 28 | + |
| 29 | + # Test that SymbolicUtils does not convert multiplication to power: |
| 30 | + tree = Node("x1") * Node("x1") |
| 31 | + eqn = convert(Symbolic, tree, operators) |
| 32 | + @test repr(eqn) ≈ "x1*x1" |
| 33 | + tree_copy = convert(Node, eqn, operators) |
| 34 | + @test repr(tree_copy) ≈ "(x1*x1)" |
16 | 35 | end |
17 | 36 |
|
18 | | -simplify_tree! = DynamicExpressions.SimplifyModule.simplify_tree! |
19 | | -combine_operators = DynamicExpressions.SimplifyModule.combine_operators |
20 | | - |
21 | | -binary_operators = (+, -, /, *) |
22 | | - |
23 | | -index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1] |
24 | | - |
25 | | -operators = OperatorEnum(; binary_operators=binary_operators) |
26 | | - |
27 | | -tree = Node("x1") + Node("x1") |
28 | | - |
29 | | -# Should simplify to 2*x1: |
30 | | -eqn = convert(Symbolic, tree, operators) |
31 | | -eqn2 = simplify(eqn) |
32 | | -# Should correctly simplify to 2 x1: |
33 | | -# (although it might use 2(x1^1)) |
34 | | -@test occursin("2", "$(repr(eqn2)[1])") |
35 | | - |
36 | | -# Let's convert back the simplified version. |
37 | | -# This should remove the ^ operator: |
38 | | -tree = convert(Node, eqn2, operators) |
39 | | -# Make sure one of the nodes is now 2.0: |
40 | | -@test (tree.l.constant ? tree.l : tree.r).val == 2 |
41 | | -# Make sure the other node is x1: |
42 | | -@test (!tree.l.constant ? tree.l : tree.r).feature == 1 |
43 | | - |
44 | | -# Finally, let's try converting a product, and ensure |
45 | | -# that SymbolicUtils does not convert it to a power: |
46 | | -tree = Node("x1") * Node("x1") |
47 | | -eqn = convert(Symbolic, tree, operators) |
48 | | -@test repr(eqn) ≈ "x1*x1" |
49 | | -# Test converting back: |
50 | | -tree_copy = convert(Node, eqn, operators) |
51 | | -@test repr(tree_copy) ≈ "(x1*x1)" |
52 | | - |
53 | | -# Let's test a much more complex function, |
54 | | -# with custom operators, and unary operators: |
55 | | -x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") |
56 | | -pow_abs2(x, y) = abs(x)^y |
57 | | - |
58 | | -operators = OperatorEnum(; |
59 | | - binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin) |
60 | | -) |
61 | | -@extend_operators operators |
62 | | -tree = ( |
63 | | - ((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + ( |
64 | | - sin( |
65 | | - custom_cos( |
66 | | - sin(1.2926733 - 1.6606787) / |
67 | | - sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426), |
68 | | - ) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)), |
69 | | - ) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262)) |
| 37 | +@testitem "Complex expression simplification" begin |
| 38 | + using DynamicExpressions, Test |
| 39 | + using SymbolicUtils: simplify, Symbolic |
| 40 | + import Random: MersenneTwister |
| 41 | + include("test_params.jl") |
| 42 | + |
| 43 | + x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") |
| 44 | + pow_abs2(x, y) = abs(x)^y |
| 45 | + |
| 46 | + operators = OperatorEnum(; |
| 47 | + binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin) |
| 48 | + ) |
| 49 | + @extend_operators operators |
| 50 | + tree = ( |
| 51 | + ((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + ( |
| 52 | + sin( |
| 53 | + custom_cos( |
| 54 | + sin(1.2926733 - 1.6606787) / |
| 55 | + sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426), |
| 56 | + ) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)), |
| 57 | + ) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262)) |
| 58 | + ) |
70 | 59 | ) |
71 | | -) |
72 | | -# We use `index_functions` to avoid converting the custom operators into the primitives. |
73 | | -eqn = convert(Symbolic, tree, operators; index_functions=true) |
74 | | - |
75 | | -tree_copy = convert(Node, eqn, operators) |
76 | | -tree_copy2 = convert(Node, simplify(eqn), operators) |
77 | | -# Too difficult to check the representation, so we check by evaluation: |
78 | | -N = 100 |
79 | | -X = rand(MersenneTwister(0), 3, N) .+ 0.1 |
80 | | -output1, flag1 = eval_tree_array(tree, X, operators) |
81 | | -output2, flag2 = eval_tree_array(tree_copy, X, operators) |
82 | | -output3, flag3 = eval_tree_array(tree_copy2, X, operators) |
83 | | - |
84 | | -@test isapprox(output1, output2, atol=1e-4 * sqrt(N)) |
85 | | -# Simplified equation may give a different answer due to rounding errors, |
86 | | -# so we weaken the requirement: |
87 | | -@test isapprox(output1, output3, atol=1e-2 * sqrt(N)) |
| 60 | + |
| 61 | + eqn = convert(Symbolic, tree, operators; index_functions=true) |
| 62 | + tree_copy = convert(Node, eqn, operators) |
| 63 | + tree_copy2 = convert(Node, simplify(eqn), operators) |
| 64 | + |
| 65 | + N = 100 |
| 66 | + X = rand(MersenneTwister(0), 3, N) .+ 0.1 |
| 67 | + output1, flag1 = eval_tree_array(tree, X, operators) |
| 68 | + output2, flag2 = eval_tree_array(tree_copy, X, operators) |
| 69 | + output3, flag3 = eval_tree_array(tree_copy2, X, operators) |
| 70 | + |
| 71 | + @test isapprox(output1, output2, atol=1e-4 * sqrt(N)) |
| 72 | + @test isapprox(output1, output3, atol=1e-2 * sqrt(N)) |
| 73 | +end |
88 | 74 |
|
89 | | -############################################################################### |
90 | | -## Hit other parts of `simplify_tree!` and `combine_operators` to increase |
91 | | -## code coverage: |
92 | | -operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin)) |
93 | | -x1, x2, x3 = [Node(; feature=i) for i in 1:3] |
94 | | - |
95 | | -# unary operator applied to constant => constant: |
96 | | -tree = Node(1, Node(; val=0.0)) |
97 | | -@test repr(tree) ≈ "cos(0.0)" |
98 | | -@test repr(simplify_tree!(tree, operators)) ≈ "1.0" |
99 | | - |
100 | | -# except when the result is a NaN, then we don't change it: |
101 | | -tree = Node(1, Node(; val=NaN)) |
102 | | -@test repr(tree) ≈ "cos(NaN)" |
103 | | -@test repr(simplify_tree!(tree, operators)) ≈ "cos(NaN)" |
104 | | - |
105 | | -# the same as above, but inside a binary tree. |
106 | | -tree = |
107 | | - Node(1, Node(1, Node(; val=0.1), Node(; val=0.2)) + Node(; val=0.2)) + Node(; val=2.0) |
108 | | -@test repr(tree) ≈ "(cos((0.1 + 0.2) + 0.2) + 2.0)" |
109 | | -@test repr(combine_operators(tree, operators)) ≈ "(cos(0.4 + 0.1) + 2.0)" |
110 | | - |
111 | | -# left is constant: |
112 | | -tree = Node(; val=0.5) + (Node(; val=0.2) + x1) |
113 | | -@test repr(tree) ≈ "(0.5 + (0.2 + x1))" |
114 | | -@test repr(combine_operators(tree, operators)) ≈ "(x1 + 0.7)" |
| 75 | +@testitem "Constant folding" begin |
| 76 | + using DynamicExpressions, Test |
| 77 | + |
| 78 | + operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin)) |
| 79 | + x1, x2, x3 = [Node(; feature=i) for i in 1:3] |
| 80 | + |
| 81 | + # Unary operator applied to constant => constant: |
| 82 | + tree = Node(1, Node(; val=0.0)) |
| 83 | + @test repr(tree) ≈ "cos(0.0)" |
| 84 | + @test repr(simplify_tree!(tree, operators)) ≈ "1.0" |
| 85 | + |
| 86 | + # NaN handling |
| 87 | + tree = Node(1, Node(; val=NaN)) |
| 88 | + @test repr(tree) ≈ "cos(NaN)" |
| 89 | + @test repr(simplify_tree!(tree, operators)) ≈ "cos(NaN)" |
| 90 | + |
| 91 | + # Nested constant folding |
| 92 | + tree = Node(1, Node(1, Node(; val=0.1), Node(; val=0.2)) + Node(; val=0.2)) + Node(; val=2.0) |
| 93 | + @test repr(tree) ≈ "(cos((0.1 + 0.2) + 0.2) + 2.0)" |
| 94 | + @test repr(combine_operators(tree, operators)) ≈ "(cos(0.4 + 0.1) + 2.0)" |
| 95 | +end |
115 | 96 |
|
116 | 97 | # (const - (const - var)) => (var - const) |
117 | 98 | tree = Node(2, Node(; val=0.5), Node(; val=0.2) - x1) |
|
0 commit comments