Skip to content

Commit 5d54b6b

Browse files
authored
Merge pull request MilesCranmer#498 from MilesCranmer/issue178
fix: preserve state when niterations=0
2 parents e8b719f + 6c611f3 commit 5d54b6b

File tree

3 files changed

+85
-17
lines changed

3 files changed

+85
-17
lines changed

src/SymbolicRegression.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,12 +803,32 @@ function _initialize_search!(
803803
end
804804
return nothing
805805
end
806+
807+
function _preserve_loaded_state!(
808+
state::AbstractSearchState{T,L,N},
809+
ropt::AbstractRuntimeOptions,
810+
options::AbstractOptions,
811+
) where {T,L,N}
812+
nout = length(state.worker_output)
813+
for j in 1:nout, i in 1:(options.populations)
814+
(pop, _, _, _) = extract_from_worker(
815+
state.worker_output[j][i], Population{T,L,N}, HallOfFame{T,L,N}
816+
)
817+
state.last_pops[j][i] = copy(pop)
818+
end
819+
return nothing
820+
end
821+
806822
function _warmup_search!(
807823
state::AbstractSearchState{T,L,N},
808824
datasets,
809825
ropt::AbstractRuntimeOptions,
810826
options::AbstractOptions,
811827
) where {T,L,N}
828+
if ropt.niterations == 0
829+
return _preserve_loaded_state!(state, ropt, options)
830+
end
831+
812832
nout = length(datasets)
813833
for j in 1:nout, i in 1:(options.populations)
814834
dataset = datasets[j]

test/runtests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ end
144144
include("test_deterministic.jl")
145145
end
146146

147-
@testitem "Testing whether early stop criteria works." tags = [:part2] begin
148-
include("test_early_stop.jl")
149-
end
147+
include("test_early_stop.jl")
150148

151149
include("test_mlj.jl")
152150

test/test_early_stop.jl

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,69 @@
1-
using SymbolicRegression
1+
@testitem "Early stop condition" tags = [:part2] begin
2+
using SymbolicRegression
23

3-
X = randn(Float32, 5, 100)
4-
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2
4+
X = randn(Float32, 5, 100)
5+
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2
56

6-
early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10))
7+
early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10))
78

8-
options = SymbolicRegression.Options(;
9-
binary_operators=(+, *, /, -),
10-
unary_operators=(cos, exp),
11-
populations=20,
12-
early_stop_condition=early_stop,
13-
)
9+
options = SymbolicRegression.Options(;
10+
binary_operators=(+, *, /, -),
11+
unary_operators=(cos, exp),
12+
populations=20,
13+
early_stop_condition=early_stop,
14+
)
1415

15-
hof = equation_search(X, y; options=options, niterations=1_000_000_000)
16+
hof = equation_search(X, y; options=options, niterations=1_000_000_000)
1617

17-
@test any(
18-
early_stop(member.loss, count_nodes(member.tree)) for member in hof.members[hof.exists]
19-
)
18+
@test any(
19+
early_stop(member.loss, count_nodes(member.tree)) for
20+
member in hof.members[hof.exists]
21+
)
22+
end
23+
24+
@testitem "State preservation with niterations=0" tags = [:part2] begin
25+
using SymbolicRegression
26+
using Random
27+
28+
# Regression test for https://github.com/MilesCranmer/SymbolicRegression.jl/issues/178
29+
30+
rng = MersenneTwister(42)
31+
X = randn(rng, 2, 10)
32+
y = X[1, :] .+ X[2, :]
33+
34+
options = Options(;
35+
binary_operators=(+,),
36+
unary_operators=(),
37+
verbosity=0,
38+
progress=false,
39+
population_size=5,
40+
populations=2,
41+
maxsize=5,
42+
tournament_selection_n=2,
43+
)
44+
45+
# Manually create saved state
46+
dataset = Dataset(X, y)
47+
pop1 = Population(dataset; population_size=5, nlength=3, options=options, nfeatures=2)
48+
pop2 = Population(dataset; population_size=5, nlength=3, options=options, nfeatures=2)
49+
hof = HallOfFame(options, dataset)
50+
51+
saved_pops = [[pop1, pop2]]
52+
saved_hof = [hof]
53+
saved_state = (saved_pops, saved_hof)
54+
55+
# Run with niterations=0 - should preserve populations
56+
result_pops, result_hof = equation_search(
57+
X,
58+
y;
59+
niterations=0,
60+
saved_state=saved_state,
61+
options=options,
62+
parallelism=:serial,
63+
return_state=true,
64+
)
65+
66+
# Verify populations are preserved (not reset to size 1)
67+
@test length(result_pops[1]) == 2
68+
@test all(pop -> length(pop.members) == 5, result_pops[1])
69+
end

0 commit comments

Comments
 (0)