# Patch overview (commentary placed ahead of the first `diff --git` header).
# NOTE(review): the original line breaks of this patch appear to have been collapsed onto
# three physical lines; restore them from the originating commit before trying to apply it.
#
# src/Mutate.jl
#   Two hunks each add `mutation_choice=mutation_choice` next to `parent_ref=parent_ref`
#   in a keyword-argument list. The callee name is outside the visible hunk context
#   (presumably `create_child` -- confirm against the full file).
# src/PopMember.jl
#   Two `create_child` methods gain the keyword
#   `mutation_choice::Union{Symbol,Nothing}=nothing`. The visible context does not show
#   the keyword being read inside either body.
# src/Population.jl
#   `best_sub_pop` now slices `best_idx[1:min(topn, pop.n)]` instead of `best_idx[1:topn]`,
#   so a `topn` larger than `pop.n` no longer indexes past the end of `best_idx`.
# src/SymbolicRegression.jl
#   One line before `example_dataset = first(datasets)` is removed and re-added with no
#   visible content (presumably a whitespace-only change -- TODO confirm).
# test/test_abstract_popmember.jl
#   - imports `strip_metadata` from SymbolicRegression;
#   - constrains `CustomPopMember`'s third type parameter to `N<:AbstractExpression{T}`;
#   - adds the `mutation_choice` keyword to two method signatures (apparently the test's
#     `create_child` overloads -- their names are outside the hunk context);
#   - adds a `strip_metadata(::CustomPopMember, options, dataset)` method that rebuilds the
#     member from the stripped tree, a fresh `get_birth_order` value, a recomputed
#     complexity, and the original cost/loss/ref/parent/generation fields;
#   - changes `populations=1` to `populations=2` in the existing search options;
#   - adds tests that `parse_guesses` and `equation_search(...; guesses=["x1 - x2"])`
#     produce `CustomPopMember` instances.
# test/test_units.jl
#   Adds `using MLJBase` to the "Miscellaneous tests of unit interface" test item.
diff --git a/src/Mutate.jl b/src/Mutate.jl index 412d7ea06..fc986df14 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -261,6 +261,7 @@ end before_loss, options; parent_ref=parent_ref, + mutation_choice=mutation_choice, ), mutation_accepted, num_evals, @@ -284,6 +285,7 @@ end before_loss, options; parent_ref=parent_ref, + mutation_choice=mutation_choice, ), mutation_accepted, num_evals, diff --git a/src/PopMember.jl b/src/PopMember.jl index 71f8707de..4343f947e 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -213,6 +213,7 @@ function create_child( loss::L, options; complexity::Union{Int,Nothing}=nothing, + mutation_choice::Union{Symbol,Nothing}=nothing, parent_ref, ) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) @@ -241,6 +242,7 @@ function create_child( loss::L, options; complexity::Union{Int,Nothing}=nothing, + mutation_choice::Union{Symbol,Nothing}=nothing, parent_ref, ) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) diff --git a/src/Population.jl b/src/Population.jl index 00f603258..594b9bc9f 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -212,7 +212,9 @@ end # Return best 10 examples function best_sub_pop(pop::P; topn::Int=10)::P where {P<:Population} best_idx = sortperm([pop.members[member].cost for member in 1:(pop.n)]) - return Population(pop.members[best_idx[1:topn]]) + # Ensure we don't try to access more elements than exist in the population + actual_topn = min(topn, pop.n) + return Population(pop.members[best_idx[1:actual_topn]]) end function record_population(pop::Population, options::AbstractOptions)::RecordType diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 4e9443f23..3de51a402 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -628,7 +628,7 @@ end datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions ) where {T,L,D<:Dataset{T,L}} stdin_reader 
= watch_stream(options.input_stream) - + example_dataset = first(datasets) record = RecordType() @recorder record["options"] = "$(options)" diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index eabc6886d..73107c453 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -5,9 +5,11 @@ using DispatchDoctor: @unstable import SymbolicRegression.PopMemberModule: create_child + import SymbolicRegression: strip_metadata # Define a custom PopMember that tracks generation count - mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + mutable struct CustomPopMember{T,L,N<:AbstractExpression{T}} <: + SymbolicRegression.AbstractPopMember{T,L,N} tree::N cost::L loss::L @@ -107,6 +109,7 @@ loss::L, options; complexity::Union{Int,Nothing}=nothing, + mutation_choice::Union{Symbol,Nothing}=nothing, parent_ref, ) where {T,L} actual_complexity = @something complexity SymbolicRegression.compute_complexity( @@ -131,6 +134,7 @@ loss::L, options; complexity::Union{Int,Nothing}=nothing, + mutation_choice::Union{Symbol,Nothing}=nothing, parent_ref, ) where {T,L,N<:AbstractExpression{T}} actual_complexity = @something complexity SymbolicRegression.compute_complexity( @@ -151,13 +155,31 @@ ) end + function strip_metadata( + member::CustomPopMember, + options::SymbolicRegression.AbstractOptions, + dataset::SymbolicRegression.Dataset{T,L}, + ) where {T,L} + complexity = SymbolicRegression.compute_complexity(member.tree, options) + return CustomPopMember( + strip_metadata(member.tree, options, dataset), + member.cost, + member.loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + complexity, + member.ref, + member.parent, + member.generation, + ) + end + # Test that we can run equation_search with CustomPopMember X = randn(Float32, 2, 100) y = @. 
X[1, :]^2 - X[2, :] options = SymbolicRegression.Options(; binary_operators=[+, -], - populations=1, + populations=2, population_size=20, maxsize=5, popmember_type=CustomPopMember, @@ -189,4 +211,47 @@ @test !isnothing(best_idx) best_member = hall_of_fame.members[best_idx] @test best_member isa CustomPopMember + + # Test that guesses API returns CustomPopMember instances + guess_X = randn(Float32, 2, 80) + guess_y = @. guess_X[1, :] - guess_X[2, :] + guess_dataset = SymbolicRegression.Dataset(guess_X, guess_y) + + guess_options = SymbolicRegression.Options(; + binary_operators=[+, -], + populations=1, + population_size=5, + tournament_selection_n=2, + maxsize=4, + popmember_type=CustomPopMember, + deterministic=true, + seed=1, + verbosity=0, + progress=false, + ) + + parsed = SymbolicRegression.parse_guesses( + CustomPopMember{Float32,Float32}, ["x1 - x2"], [guess_dataset], guess_options + ) + + @test length(parsed) == 1 + @test length(parsed[1]) == 1 + parsed_member = parsed[1][1] + @test parsed_member isa CustomPopMember{Float32,Float32} + @test isapprox(parsed_member.loss, 0.0f0; atol=1.0f-6) + + # Confirm equation_search accepts guesses with CustomPopMember + hof_from_guess = equation_search( + guess_X, + guess_y; + options=guess_options, + guesses=["x1 - x2"], + niterations=0, + parallelism=:serial, + ) + + @test sum(hof_from_guess.exists) > 0 + guess_best_idx = findlast(hof_from_guess.exists) + @test !isnothing(guess_best_idx) + @test hof_from_guess.members[guess_best_idx] isa CustomPopMember end diff --git a/test/test_units.jl b/test/test_units.jl index dd36ca6c9..676ca72a6 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -430,6 +430,7 @@ end end @testitem "Miscellaneous tests of unit interface" tags = [:part3] begin + using MLJBase using SymbolicRegression using DynamicQuantities using SymbolicRegression.DimensionalAnalysisModule: @maybe_return_call, WildcardQuantity