Skip to content

Commit 7d2a542

Browse files
committed
test: incorporate turbo and bumper in supposition test
1 parent 0af1700 commit 7d2a542

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

test/test_supposition_consistency.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
using DynamicExpressions
88
using DynamicExpressions:
99
string_tree, parse_expression, eval_tree_array, Node, get_operators, get_tree
10+
using LoopVectorization: LoopVectorization
11+
using Bumper: Bumper
1012

1113
# bring the generator into scope
1214
include("supposition_utils.jl")
@@ -20,7 +22,7 @@
2022
T; num_features=n_features, max_layers=max_layers, operators=operators
2123
)
2224

23-
@check function roundtrip_string(ex=expr_gen)
25+
result = @check function roundtrip_string(ex=expr_gen)
2426
tree_str = string_tree(ex)
2527
ex_parsed = parse_expression(
2628
Meta.parse(tree_str);
@@ -30,15 +32,25 @@
3032
)
3133
return ex == ex_parsed
3234
end
35+
@test something(result.result) isa Supposition.Pass
3336

3437
input_gen = make_input_matrix_generator(T; n_features)
35-
@check max_examples = 1024 function eval_against_string(ex=expr_gen, X=input_gen)
36-
expression_result, ok = eval_tree_array(ex, X)
37-
!ok && return true # If the expression is not valid, we can't test it
38+
args_gen = map(
39+
(ex, X, turbo, bumper) -> (; ex, X, turbo, bumper),
40+
expr_gen,
41+
input_gen,
42+
Data.Booleans(),
43+
Data.Booleans(),
44+
)
45+
# We only consider expressions that don't have NaN/Inf/etc.
46+
clean_args_gen = filter(args -> eval_tree_array(args.ex, args.X)[2], args_gen)
47+
result2 = @check max_examples = 1000 function eval_against_string(args=clean_args_gen)
48+
(; ex, X, turbo, bumper) = args
49+
expression_result, ok = eval_tree_array(ex, X; turbo, bumper)
3850
tree_str = string_tree(ex)
39-
f_sym = gensym("f")
4051
f = eval(Meta.parse("(x1, x2, x3, x4, x5) -> ($tree_str)"))
4152
true_result = Float64[Base.invokelatest(f, x...) for x in eachcol(X)]
42-
return expression_result true_result
53+
return ok && expression_result true_result
4354
end
55+
@test something(result2.result) isa Supposition.Pass
4456
end

0 commit comments

Comments
 (0)