Skip to content

Commit fb248de

Browse files
committed
test: integrate Supposition testing
1 parent 422b5c5 commit fb248de

File tree

4 files changed

+112
-0
lines changed

4 files changed

+112
-0
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1717
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1818
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1919
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
20+
Supposition = "5a0628fe-1738-4658-9b6d-0b7605a9755b"
2021
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
2122
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2223
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/supposition_utils.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# supposition_utils.jl
2+
#
3+
# Helper that builds a Supposition generator returning fully-random
4+
# DynamicExpressions.Expression objects whose node type is Node{T,D}.
5+
# D is inferred from `operators`.
6+
7+
module SuppositionUtils
8+
9+
using Supposition: Data
10+
using DynamicExpressions: Node, Expression, OperatorEnum
11+
using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals!
12+
empty_all_globals!()
13+
14+
function make_expression_generator(
15+
::Type{T};
16+
num_features::Int=5,
17+
operators::OperatorEnum=OperatorEnum(((abs, cos), (+, -, *, /))),
18+
max_layers::Int=3,
19+
) where {T}
20+
D = length(operators.ops)
21+
22+
val_gen = Data.Floats{T}(; nans=false, infs=false)
23+
val_node_gen = map(v -> Node{T,D}(; val=v), val_gen)
24+
25+
feature_gen = Data.SampledFrom(1:num_features)
26+
feature_node_gen = map(i -> Node{T,D}(; feature=i), feature_gen)
27+
28+
leaf_gen = val_node_gen | feature_node_gen
29+
30+
wrapper_funcs = ntuple(
31+
degree -> let op_list = operators[degree]
32+
op_gen = Data.SampledFrom(1:length(op_list))
33+
34+
child -> map(
35+
(op_idx, args...) -> Node{T,D}(; op=op_idx, children=args),
36+
op_gen,
37+
ntuple(_ -> child, degree)...,
38+
)
39+
end,
40+
Val(D),
41+
)
42+
expr_wrap(child) = foldl(|, (w(child) for w in wrapper_funcs))
43+
tree_gen = Data.Recursive(leaf_gen, expr_wrap; max_layers)
44+
return map(
45+
t -> Expression(t; operators, variable_names=["x$i" for i in 1:num_features]),
46+
tree_gen,
47+
)
48+
end
49+
50+
# inside module SuppositionUtils
51+
function make_input_matrix_generator(
52+
::Type{T}=Float64; n_features::Int=5, min_batch::Int=1, max_batch::Int=16
53+
) where {T}
54+
elem_gen = Data.Floats{T}(; nans=false, infs=false)
55+
batch_gen = Data.Integers(min_batch, max_batch)
56+
57+
Data.bind(batch_gen) do bs
58+
vec_len = n_features * bs
59+
vec_gen = Data.Vectors(elem_gen; min_size=vec_len, max_size=vec_len)
60+
Data.map(v -> reshape(v, n_features, bs), vec_gen)
61+
end
62+
end
63+
64+
end
65+
66+
using .SuppositionUtils: make_input_matrix_generator, make_expression_generator
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
@testitem "Supposition round-trip consistency" begin
2+
using Test
3+
using Random
4+
5+
using Supposition
6+
using Supposition: @check, Data
7+
using DynamicExpressions
8+
using DynamicExpressions:
9+
string_tree, parse_expression, eval_tree_array, Node, get_operators, get_tree
10+
11+
# bring the generator into scope
12+
include("supposition_utils.jl")
13+
14+
n_features = 5
15+
max_layers = 20
16+
T = Float64
17+
operators = OperatorEnum(((abs, cos, exp), (+, -, *, /), (fma, clamp, +, max)))
18+
19+
expr_gen = make_expression_generator(
20+
T; num_features=n_features, max_layers=max_layers, operators=operators
21+
)
22+
23+
@check function roundtrip_string(ex=expr_gen)
24+
tree_str = string_tree(ex)
25+
ex_parsed = parse_expression(
26+
Meta.parse(tree_str);
27+
operators=get_operators(ex),
28+
variable_names=["x$i" for i in 1:n_features],
29+
node_type=Node{Float64,3},
30+
)
31+
return ex == ex_parsed
32+
end
33+
34+
input_gen = make_input_matrix_generator(T; n_features)
35+
@check max_examples = 1024 function eval_against_string(ex=expr_gen, X=input_gen)
36+
expression_result, ok = eval_tree_array(ex, X)
37+
!ok && return true # If the expression is not valid, we can't test it
38+
tree_str = string_tree(ex)
39+
f_sym = gensym("f")
40+
f = eval(Meta.parse("(x1, x2, x3, x4, x5) -> ($tree_str)"))
41+
true_result = Float64[Base.invokelatest(f, x...) for x in eachcol(X)]
42+
return expression_result true_result
43+
end
44+
end

test/unittest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,4 @@ include("test_expression_math.jl")
133133
include("test_structured_expression.jl")
134134
include("test_readonlynode.jl")
135135
include("test_zygote_gradient_wrapper.jl")
136+
include("test_supposition_consistency.jl")

0 commit comments

Comments
 (0)