Skip to content

Commit 230279d

Browse files
committed
test the guesses_api with a custom pop member
1 parent 2ee896c commit 230279d

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

test/test_abstract_popmember.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
using DispatchDoctor: @unstable
66

77
import SymbolicRegression.PopMemberModule: create_child
8+
import SymbolicRegression: strip_metadata
89

910
# Define a custom PopMember that tracks generation count
10-
mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N}
11+
mutable struct CustomPopMember{T,L,N<:AbstractExpression{T}} <:
12+
SymbolicRegression.AbstractPopMember{T,L,N}
1113
tree::N
1214
cost::L
1315
loss::L
@@ -153,6 +155,24 @@
153155
)
154156
end
155157

158+
function strip_metadata(
159+
member::CustomPopMember,
160+
options::SymbolicRegression.AbstractOptions,
161+
dataset::SymbolicRegression.Dataset{T,L},
162+
) where {T,L}
163+
complexity = SymbolicRegression.compute_complexity(member.tree, options)
164+
return CustomPopMember(
165+
strip_metadata(member.tree, options, dataset),
166+
member.cost,
167+
member.loss,
168+
SymbolicRegression.get_birth_order(; deterministic=options.deterministic),
169+
complexity,
170+
member.ref,
171+
member.parent,
172+
member.generation,
173+
)
174+
end
175+
156176
# Test that we can run equation_search with CustomPopMember
157177
X = randn(Float32, 2, 100)
158178
y = @. X[1, :]^2 - X[2, :]
@@ -191,4 +211,47 @@
191211
@test !isnothing(best_idx)
192212
best_member = hall_of_fame.members[best_idx]
193213
@test best_member isa CustomPopMember
214+
215+
# Test that guesses API returns CustomPopMember instances
216+
guess_X = randn(Float32, 2, 80)
217+
guess_y = @. guess_X[1, :] - guess_X[2, :]
218+
guess_dataset = SymbolicRegression.Dataset(guess_X, guess_y)
219+
220+
guess_options = SymbolicRegression.Options(;
221+
binary_operators=[+, -],
222+
populations=1,
223+
population_size=5,
224+
tournament_selection_n=2,
225+
maxsize=4,
226+
popmember_type=CustomPopMember,
227+
deterministic=true,
228+
seed=1,
229+
verbosity=0,
230+
progress=false,
231+
)
232+
233+
parsed = SymbolicRegression.parse_guesses(
234+
CustomPopMember{Float32,Float32}, ["x1 - x2"], [guess_dataset], guess_options
235+
)
236+
237+
@test length(parsed) == 1
238+
@test length(parsed[1]) == 1
239+
parsed_member = parsed[1][1]
240+
@test parsed_member isa CustomPopMember{Float32,Float32}
241+
@test isapprox(parsed_member.loss, 0.0f0; atol=1.0f-6)
242+
243+
# Confirm equation_search accepts guesses with CustomPopMember
244+
hof_from_guess = equation_search(
245+
guess_X,
246+
guess_y;
247+
options=guess_options,
248+
guesses=["x1 - x2"],
249+
niterations=0,
250+
parallelism=:serial,
251+
)
252+
253+
@test sum(hof_from_guess.exists) > 0
254+
guess_best_idx = findlast(hof_from_guess.exists)
255+
@test !isnothing(guess_best_idx)
256+
@test hof_from_guess.members[guess_best_idx] isa CustomPopMember
194257
end

0 commit comments

Comments
 (0)