|
159 | 159 | tree = (Node(; val=2.0) * x1) + Node(; val=3.0) |
160 | 160 | @test combine_operators!(tree, operators) == tree |
161 | 161 | end |
| 162 | + |
| 163 | +@testitem "Random tree simplification" begin |
| 164 | + using DynamicExpressions, Test |
| 165 | + import DynamicExpressions.SimplifyModule: combine_operators!, simplify_tree! |
| 166 | + import Random: MersenneTwister |
| 167 | + include("tree_gen_utils.jl") |
| 168 | + |
| 169 | + operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(sin, cos)) |
| 170 | + |
| 171 | + initial_sizes = Int[] |
| 172 | + simplified_sizes = Int[] |
| 173 | + |
| 174 | + for i in 1:100 |
| 175 | + rng = MersenneTwister(i) |
| 176 | + |
| 177 | + # Generate a random tree with 3 features and size ~50 nodes |
| 178 | + tree = gen_random_tree_fixed_size(50, operators, 3, Float64, Node, rng) |
| 179 | + |
| 180 | + # Randomly set some nodes to 0 or 1 |
| 181 | + if rand(rng) < 0.5 |
| 182 | + any(tree) do node |
| 183 | + if node.degree == 0 && node.constant && rand(rng) < 0.5 |
| 184 | + node.val = rand(rng) < 0.5 ? 0.0 : 1.0 |
| 185 | + true |
| 186 | + else |
| 187 | + false |
| 188 | + end |
| 189 | + end |
| 190 | + end |
| 191 | + |
| 192 | + # Simplify it |
| 193 | + simplified = combine_operators!(copy(tree), operators) |
| 194 | + |
| 195 | + # Simplified tree should not be larger than original |
| 196 | + push!(initial_sizes, count_nodes(tree)) |
| 197 | + push!(simplified_sizes, count_nodes(simplified)) |
| 198 | + |
| 199 | + # Evaluate both trees on the same output |
| 200 | + X = randn(rng, Float64, 3, 10) |
| 201 | + output1, flag1 = eval_tree_array(tree, X, operators) |
| 202 | + output2, flag2 = eval_tree_array(simplified, X, operators) |
| 203 | + |
| 204 | + # Both should succeed or fail together |
| 205 | + @test flag1 == flag2 |
| 206 | + |
| 207 | + if flag1 && flag2 |
| 208 | + # Results should be approximately equal |
| 209 | + @test isapprox(output1, output2, rtol=1e-10) |
| 210 | + end |
| 211 | + end |
| 212 | + # At least SOME should simplify |
| 213 | + @test any(i -> initial_sizes[i] > simplified_sizes[i], 1:length(initial_sizes)) |
| 214 | +end |
0 commit comments