Skip to content

Commit 6c611f3

Browse files
committed
fix: preserve state when niterations=0
1 parent fae61b3 commit 6c611f3

File tree

3 files changed

+85
-17
lines changed

3 files changed

+85
-17
lines changed

src/SymbolicRegression.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,12 +802,32 @@ function _initialize_search!(
802802
end
803803
return nothing
804804
end
805+
806+
function _preserve_loaded_state!(
807+
state::AbstractSearchState{T,L,N},
808+
ropt::AbstractRuntimeOptions,
809+
options::AbstractOptions,
810+
) where {T,L,N}
811+
nout = length(state.worker_output)
812+
for j in 1:nout, i in 1:(options.populations)
813+
(pop, _, _, _) = extract_from_worker(
814+
state.worker_output[j][i], Population{T,L,N}, HallOfFame{T,L,N}
815+
)
816+
state.last_pops[j][i] = copy(pop)
817+
end
818+
return nothing
819+
end
820+
805821
function _warmup_search!(
806822
state::AbstractSearchState{T,L,N},
807823
datasets,
808824
ropt::AbstractRuntimeOptions,
809825
options::AbstractOptions,
810826
) where {T,L,N}
827+
if ropt.niterations == 0
828+
return _preserve_loaded_state!(state, ropt, options)
829+
end
830+
811831
nout = length(datasets)
812832
for j in 1:nout, i in 1:(options.populations)
813833
dataset = datasets[j]

test/runtests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ end
144144
include("test_deterministic.jl")
145145
end
146146

147-
@testitem "Testing whether early stop criteria works." tags = [:part2] begin
148-
include("test_early_stop.jl")
149-
end
147+
include("test_early_stop.jl")
150148

151149
include("test_mlj.jl")
152150

test/test_early_stop.jl

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,69 @@
1-
using SymbolicRegression
1+
@testitem "Early stop condition" tags = [:part2] begin
2+
using SymbolicRegression
23

3-
X = randn(Float32, 5, 100)
4-
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2
4+
X = randn(Float32, 5, 100)
5+
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2
56

6-
early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10))
7+
early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10))
78

8-
options = SymbolicRegression.Options(;
9-
binary_operators=(+, *, /, -),
10-
unary_operators=(cos, exp),
11-
populations=20,
12-
early_stop_condition=early_stop,
13-
)
9+
options = SymbolicRegression.Options(;
10+
binary_operators=(+, *, /, -),
11+
unary_operators=(cos, exp),
12+
populations=20,
13+
early_stop_condition=early_stop,
14+
)
1415

15-
hof = equation_search(X, y; options=options, niterations=1_000_000_000)
16+
hof = equation_search(X, y; options=options, niterations=1_000_000_000)
1617

17-
@test any(
18-
early_stop(member.loss, count_nodes(member.tree)) for member in hof.members[hof.exists]
19-
)
18+
@test any(
19+
early_stop(member.loss, count_nodes(member.tree)) for
20+
member in hof.members[hof.exists]
21+
)
22+
end
23+
24+
@testitem "State preservation with niterations=0" tags = [:part2] begin
25+
using SymbolicRegression
26+
using Random
27+
28+
# Regression test for https://github.com/MilesCranmer/SymbolicRegression.jl/issues/178
29+
30+
rng = MersenneTwister(42)
31+
X = randn(rng, 2, 10)
32+
y = X[1, :] .+ X[2, :]
33+
34+
options = Options(;
35+
binary_operators=(+,),
36+
unary_operators=(),
37+
verbosity=0,
38+
progress=false,
39+
population_size=5,
40+
populations=2,
41+
maxsize=5,
42+
tournament_selection_n=2,
43+
)
44+
45+
# Manually create saved state
46+
dataset = Dataset(X, y)
47+
pop1 = Population(dataset; population_size=5, nlength=3, options=options, nfeatures=2)
48+
pop2 = Population(dataset; population_size=5, nlength=3, options=options, nfeatures=2)
49+
hof = HallOfFame(options, dataset)
50+
51+
saved_pops = [[pop1, pop2]]
52+
saved_hof = [hof]
53+
saved_state = (saved_pops, saved_hof)
54+
55+
# Run with niterations=0 - should preserve populations
56+
result_pops, result_hof = equation_search(
57+
X,
58+
y;
59+
niterations=0,
60+
saved_state=saved_state,
61+
options=options,
62+
parallelism=:serial,
63+
return_state=true,
64+
)
65+
66+
# Verify populations are preserved (not reset to size 1)
67+
@test length(result_pops[1]) == 2
68+
@test all(pop -> length(pop.members) == 5, result_pops[1])
69+
end

0 commit comments

Comments
 (0)