Skip to content
2 changes: 2 additions & 0 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ end
before_loss,
options;
parent_ref=parent_ref,
mutation_choice=mutation_choice,
),
mutation_accepted,
num_evals,
Expand All @@ -284,6 +285,7 @@ end
before_loss,
options;
parent_ref=parent_ref,
mutation_choice=mutation_choice,
),
mutation_accepted,
num_evals,
Expand Down
2 changes: 2 additions & 0 deletions src/PopMember.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ function create_child(
loss::L,
options;
complexity::Union{Int,Nothing}=nothing,
mutation_choice::Union{Symbol,Nothing}=nothing,
parent_ref,
) where {T,L,P<:PopMember{T,L}}
actual_complexity = @something complexity compute_complexity(tree, options)
Expand Down Expand Up @@ -241,6 +242,7 @@ function create_child(
loss::L,
options;
complexity::Union{Int,Nothing}=nothing,
mutation_choice::Union{Symbol,Nothing}=nothing,
parent_ref,
) where {T,L,P<:PopMember{T,L}}
actual_complexity = @something complexity compute_complexity(tree, options)
Expand Down
4 changes: 3 additions & 1 deletion src/Population.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ end
# Return best 10 examples
function best_sub_pop(pop::P; topn::Int=10)::P where {P<:Population}
best_idx = sortperm([pop.members[member].cost for member in 1:(pop.n)])
return Population(pop.members[best_idx[1:topn]])
# Ensure we don't try to access more elements than exist in the population
actual_topn = min(topn, pop.n)
return Population(pop.members[best_idx[1:actual_topn]])
end

function record_population(pop::Population, options::AbstractOptions)::RecordType
Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ end
datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions
) where {T,L,D<:Dataset{T,L}}
stdin_reader = watch_stream(options.input_stream)

example_dataset = first(datasets)
record = RecordType()
@recorder record["options"] = "$(options)"

Expand Down
69 changes: 67 additions & 2 deletions test/test_abstract_popmember.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
using DispatchDoctor: @unstable

import SymbolicRegression.PopMemberModule: create_child
import SymbolicRegression: strip_metadata

# Define a custom PopMember that tracks generation count
mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N}
mutable struct CustomPopMember{T,L,N<:AbstractExpression{T}} <:
SymbolicRegression.AbstractPopMember{T,L,N}
tree::N
cost::L
loss::L
Expand Down Expand Up @@ -107,6 +109,7 @@
loss::L,
options;
complexity::Union{Int,Nothing}=nothing,
mutation_choice::Union{Symbol,Nothing}=nothing,
parent_ref,
) where {T,L}
actual_complexity = @something complexity SymbolicRegression.compute_complexity(
Expand All @@ -131,6 +134,7 @@
loss::L,
options;
complexity::Union{Int,Nothing}=nothing,
mutation_choice::Union{Symbol,Nothing}=nothing,
parent_ref,
) where {T,L,N<:AbstractExpression{T}}
actual_complexity = @something complexity SymbolicRegression.compute_complexity(
Expand All @@ -151,13 +155,31 @@
)
end

function strip_metadata(
member::CustomPopMember,
options::SymbolicRegression.AbstractOptions,
dataset::SymbolicRegression.Dataset{T,L},
) where {T,L}
complexity = SymbolicRegression.compute_complexity(member.tree, options)
return CustomPopMember(
strip_metadata(member.tree, options, dataset),
member.cost,
member.loss,
SymbolicRegression.get_birth_order(; deterministic=options.deterministic),
complexity,
member.ref,
member.parent,
member.generation,
)
end

# Test that we can run equation_search with CustomPopMember
X = randn(Float32, 2, 100)
y = @. X[1, :]^2 - X[2, :]

options = SymbolicRegression.Options(;
binary_operators=[+, -],
populations=1,
populations=2,
population_size=20,
maxsize=5,
popmember_type=CustomPopMember,
Expand Down Expand Up @@ -189,4 +211,47 @@
@test !isnothing(best_idx)
best_member = hall_of_fame.members[best_idx]
@test best_member isa CustomPopMember

# Test that guesses API returns CustomPopMember instances
guess_X = randn(Float32, 2, 80)
guess_y = @. guess_X[1, :] - guess_X[2, :]
guess_dataset = SymbolicRegression.Dataset(guess_X, guess_y)

guess_options = SymbolicRegression.Options(;
binary_operators=[+, -],
populations=1,
population_size=5,
tournament_selection_n=2,
maxsize=4,
popmember_type=CustomPopMember,
deterministic=true,
seed=1,
verbosity=0,
progress=false,
)

parsed = SymbolicRegression.parse_guesses(
CustomPopMember{Float32,Float32}, ["x1 - x2"], [guess_dataset], guess_options
)

@test length(parsed) == 1
@test length(parsed[1]) == 1
parsed_member = parsed[1][1]
@test parsed_member isa CustomPopMember{Float32,Float32}
@test isapprox(parsed_member.loss, 0.0f0; atol=1.0f-6)

# Confirm equation_search accepts guesses with CustomPopMember
hof_from_guess = equation_search(
guess_X,
guess_y;
options=guess_options,
guesses=["x1 - x2"],
niterations=0,
parallelism=:serial,
)

@test sum(hof_from_guess.exists) > 0
guess_best_idx = findlast(hof_from_guess.exists)
@test !isnothing(guess_best_idx)
@test hof_from_guess.members[guess_best_idx] isa CustomPopMember
end
1 change: 1 addition & 0 deletions test/test_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ end
end

@testitem "Miscellaneous tests of unit interface" tags = [:part3] begin
using MLJBase
using SymbolicRegression
using DynamicQuantities
using SymbolicRegression.DimensionalAnalysisModule: @maybe_return_call, WildcardQuantity
Expand Down