Skip to content

Commit c860b5a

Browse files
committed
test: random tests of simplification utilities
1 parent 4cb5a76 commit c860b5a

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

test/test_expressions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ end
109109
operators = OperatorEnum(; binary_operators=[+, -]),
110110
variable_names = [:x],
111111
)
112-
out = combine_operators(ex)
112+
out = combine_operators!(ex)
113113
@test typeof(out) === typeof(ex)
114114
@test string_tree(out) == "x + 5.0"
115115
end
@@ -375,7 +375,7 @@ end
375375
complex_expr = parse_expression(
376376
:((2.0 + x) + 3.0); operators=operators, variable_names=["x"]
377377
)
378-
simplified_expr = combine_operators(copy(complex_expr))
378+
simplified_expr = combine_operators!(copy(complex_expr))
379379
println("Original: ", complex_expr)
380380
println("Simplified: ", simplified_expr)
381381

test/test_simplification.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,56 @@ end
159159
tree = (Node(; val=2.0) * x1) + Node(; val=3.0)
160160
@test combine_operators!(tree, operators) == tree
161161
end
162+
163+
@testitem "Random tree simplification" begin
164+
using DynamicExpressions, Test
165+
import DynamicExpressions.SimplifyModule: combine_operators!, simplify_tree!
166+
import Random: MersenneTwister
167+
include("tree_gen_utils.jl")
168+
169+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(sin, cos))
170+
171+
initial_sizes = Int[]
172+
simplified_sizes = Int[]
173+
174+
for i in 1:100
175+
rng = MersenneTwister(i)
176+
177+
# Generate a random tree with 3 features and size ~50 nodes
178+
tree = gen_random_tree_fixed_size(50, operators, 3, Float64, Node, rng)
179+
180+
# Randomly set some nodes to 0 or 1
181+
if rand(rng) < 0.5
182+
any(tree) do node
183+
if node.degree == 0 && node.constant && rand(rng) < 0.5
184+
node.val = rand(rng) < 0.5 ? 0.0 : 1.0
185+
true
186+
else
187+
false
188+
end
189+
end
190+
end
191+
192+
# Simplify it
193+
simplified = combine_operators!(copy(tree), operators)
194+
195+
# Simplified tree should not be larger than original
196+
push!(initial_sizes, count_nodes(tree))
197+
push!(simplified_sizes, count_nodes(simplified))
198+
199+
# Evaluate both trees on the same output
200+
X = randn(rng, Float64, 3, 10)
201+
output1, flag1 = eval_tree_array(tree, X, operators)
202+
output2, flag2 = eval_tree_array(simplified, X, operators)
203+
204+
# Both should succeed or fail together
205+
@test flag1 == flag2
206+
207+
if flag1 && flag2
208+
# Results should be approximately equal
209+
@test isapprox(output1, output2, rtol=1e-10)
210+
end
211+
end
212+
# At least SOME should simplify
213+
@test any(i -> initial_sizes[i] > simplified_sizes[i], 1:length(initial_sizes))
214+
end

0 commit comments

Comments
 (0)