Skip to content

Commit 46d5a1d

Browse files
Add SymbolicRegression integration test (#2965)
* Add SymbolicRegression integration test * Relax compat bounds for SymbolicRegression integration test * Rewrite SymbolicRegression integration test * Simplify SymbolicRegression integration test deps * Remove unnecessary Dataset extra field in SR integration test
1 parent 7944dec commit 46d5a1d

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

.github/workflows/Integration.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ jobs:
5050
- DifferentiationInterface
5151
- Distributions
5252
- DynamicExpressions
53+
- SymbolicRegression
5354
- Lux
5455
- SciML
5556
- KernelAbstractions
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[deps]
2+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4+
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
5+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6+
7+
[sources]
8+
Enzyme = {path = "../../.."}
9+
EnzymeCore = {path = "../../../lib/EnzymeCore"}
10+
11+
[compat]
12+
SymbolicRegression = "1.12.0"
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using Test
2+
using Random: MersenneTwister
3+
using SymbolicRegression
4+
using Enzyme
5+
6+
rng = MersenneTwister(0)
7+
X = rand(rng, 2, 16)
8+
# Choose a constant target so constant-optimization is guaranteed to run,
9+
# exercising `autodiff_backend=:Enzyme` deterministically.
10+
y = fill(1.0, size(X, 2))
11+
12+
dataset = Dataset(
13+
X,
14+
y;
15+
variable_names=["x1", "x2"],
16+
)
17+
18+
options = Options(
19+
binary_operators=[+, *, -],
20+
unary_operators=[],
21+
populations=1,
22+
population_size=20,
23+
ncycles_per_iteration=5,
24+
maxsize=8,
25+
autodiff_backend=:Enzyme,
26+
optimizer_probability=1.0,
27+
seed=0,
28+
deterministic=true,
29+
verbosity=0,
30+
save_to_file=false,
31+
)
32+
33+
hall_of_fame = equation_search(
34+
dataset;
35+
niterations=2,
36+
options=options,
37+
parallelism=:serial,
38+
runtests=false,
39+
progress=false,
40+
)
41+
42+
best_loss = minimum(
43+
member.loss for (member, exists) in zip(hall_of_fame.members, hall_of_fame.exists) if
44+
exists
45+
)
46+
47+
@test best_loss < 1e-8

0 commit comments

Comments
 (0)