From fbb2a4ff223e1c8bb7f2d07f47ca652fadd0fd6e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 00:39:41 +0100 Subject: [PATCH 01/13] feat: create AbstractPopMember for customizing PopMember --- src/ConstantOptimization.jl | 8 ++--- src/HallOfFame.jl | 31 ++++++++++++-------- src/MLJInterface.jl | 3 +- src/Migration.jl | 4 +-- src/Mutate.jl | 17 ++++++----- src/Options.jl | 6 ++++ src/OptionsStruct.jl | 2 ++ src/PopMember.jl | 40 ++++++++++++++++++++----- src/Population.jl | 58 ++++++++++++++++++++++++------------- src/SymbolicRegression.jl | 24 ++++++++++----- 10 files changed, 132 insertions(+), 61 deletions(-) diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 585a6ef78..7fa7798d8 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -17,7 +17,7 @@ using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction using ..UtilsModule: get_birth_order, PerTaskCache, stable_get! using ..LossFunctionsModule: eval_loss, loss_to_cost -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember function can_optimize(::AbstractExpression{T}, options) where {T} return can_optimize(T, options) @@ -31,7 +31,7 @@ end member::P, options::AbstractOptions; rng::AbstractRNG=default_rng(), -)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,N,P<:AbstractPopMember{T,L,N}} can_optimize(member.tree, options) || return (member, 0.0) nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) @@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T,L,N,P<:AbstractPopMember{T,L,N}} tree = member.tree x0, refs = get_scalar_constants(tree) @assert count_constants_for_optimization(tree) == length(x0) @@ -76,7 +76,7 @@ function _optimize_constants( end function _optimize_constants_inner( f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {F,G,T,L,N,P<:AbstractPopMember{T,L,N}} obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing f else diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index c90fbe3a4..4a1b7841e 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -6,12 +6,13 @@ using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf """ - HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N}} List of the best members seen all time in `.members`, with `.members[c]` being the best member seen at complexity c. Including only the members which actually @@ -19,15 +20,19 @@ have been set, you can run `.members[exists]`. # Fields -- `members::Array{PopMember{T,L,N},1}`: List of the best members seen all time. +- `members::Array{PM,1}`: List of the best members seen all time. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. """ -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct HallOfFame{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} exists::Array{Bool,1} #Whether it has been set end -function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} +function Base.show( + io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N,PM} +) where {T,L,N,PM} println(io, "HallOfFame{...}:") for i in eachindex(hof.members, hof.exists) s_member, s_exists = if hof.exists[i] @@ -47,8 +52,8 @@ function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where end return nothing end -function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,HOF<:HallOfFame{T,L,N}} - return PopMember{T,L,N} +function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,PM,HOF<:HallOfFame{T,L,N,PM}} + return PM end """ @@ -69,7 +74,7 @@ function HallOfFame( ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) - return HallOfFame{T,L,typeof(base_tree)}( + return HallOfFame{T,L,typeof(base_tree),PopMember{T,L,typeof(base_tree)}}( [ PopMember( copy(base_tree), @@ -93,11 +98,10 @@ end """ calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,P}) where {T<:DATA_TYPE,L<:LOSS_TYPE} """ -function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} +function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N,PM}) where {T,L,N,PM} # TODO - remove dataset from args. - P = PopMember{T,L,N} # Dominating pareto curve - must be better than all simpler equations - dominating = P[] + dominating = PM[] for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue @@ -276,4 +280,7 @@ function format_hall_of_fame(hof::AbstractVector{<:HallOfFame}, options) end # TODO: Re-use this in `string_dominating_pareto_curve` +# Type accessor for HallOfFame +popmember_type(::Type{<:HallOfFame{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 5845e2db2..b441fda2d 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -39,7 +39,8 @@ using ..CoreModule: ExpressionSpec, get_expression_type, check_warm_start_compatibility -using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using ..CoreModule.OptionsModule: + DEFAULT_OPTIONS, OPTION_DESCRIPTIONS, default_popmember_type using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore diff --git a/src/Migration.jl b/src/Migration.jl index f7fe61b89..3988b81ea 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -2,7 +2,7 @@ module MigrationModule using ..CoreModule: AbstractOptions using ..PopulationModule: Population -using ..PopMemberModule: PopMember, reset_birth! +using ..PopMemberModule: AbstractPopMember, PopMember, reset_birth! using ..UtilsModule: poisson_sample """ @@ -14,7 +14,7 @@ to do so. The original migrant population is not modified. Pass with, e.g., """ function migrate!( migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat -) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} +) where {T,L,N,PM<:AbstractPopMember{T,L,N},P<:Population{T,L,N,PM}} base_pop = migration.second population_size = length(base_pop.members) mean_number_replaced = population_size * frac diff --git a/src/Mutate.jl b/src/Mutate.jl index f5fd88457..996f474a1 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -61,7 +61,8 @@ This struct encapsulates the result of a mutation operation. Either a new expres Return the `member` if you want to return immediately, and have computed the loss value as part of the mutation. """ -struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P} +struct MutationResult{N<:AbstractExpression,P<:AbstractPopMember} <: + AbstractMutationResult{N,P} tree::Union{N,Nothing} member::Union{P,Nothing} num_evals::Float64 @@ -73,7 +74,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes member::Union{_P,Nothing}=nothing, num_evals::Float64=0.0, return_immediately::Bool=false, - ) where {_N<:AbstractExpression,_P<:PopMember} + ) where {_N<:AbstractExpression,_P<:AbstractPopMember} @assert( (tree === nothing) ⊻ (member === nothing), "Mutation result must return either a tree or a pop member, not both" @@ -83,7 +84,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes end """ - condition_mutation_weights!(weights::AbstractMutationWeights, member::PopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) + condition_mutation_weights!(weights::AbstractMutationWeights, member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) Adjusts the mutation weights based on the properties of the current member and options. @@ -93,7 +94,7 @@ Note that the weights were already copied, so you don't need to worry about muta # Arguments - `weights::AbstractMutationWeights`: The mutation weights to be adjusted. -- `member::PopMember`: The current population member being mutated. +- `member::AbstractPopMember`: The current population member being mutated. - `options::AbstractOptions`: The options that guide the mutation process. - `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. - `nfeatures::Int`: The number of features available in the dataset. @@ -104,7 +105,7 @@ function condition_mutation_weights!( options::AbstractOptions, curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:AbstractExpression,P<:AbstractPopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 @@ -181,7 +182,7 @@ end tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} parent_ref = member.ref num_evals = 0.0 @@ -665,7 +666,7 @@ function crossover_generation( curmaxsize::Int, options::AbstractOptions; recorder::RecordType=RecordType(), -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:AbstractPopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree crossover_accepted = false diff --git a/src/Options.jl b/src/Options.jl index cde0d5a31..c746b1332 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -40,6 +40,9 @@ using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutatio import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization using ..UtilsModule: @save_kwargs, @ignore + +# Forward declaration - will be defined in PopMemberModule +function default_popmember_type end using ..ExpressionSpecModule: AbstractExpressionSpec, ExpressionSpec, @@ -651,6 +654,7 @@ $(OPTION_DESCRIPTIONS) terminal_width::Union{Nothing,Integer}=nothing, use_recorder::Bool=false, recorder_file::AbstractString="pysr_recorder.json", + popmember_type::Type=default_popmember_type(), ### Not search options; just construction options: define_helper_functions::Bool=true, ######################################### @@ -1030,6 +1034,7 @@ $(OPTION_DESCRIPTIONS) expression_type, typeof(expression_options), typeof(set_mutation_weights), + popmember_type, turbo, bumper, deprecated_return_state::Union{Bool,Nothing}, @@ -1103,6 +1108,7 @@ $(OPTION_DESCRIPTIONS) deterministic, define_helper_functions, use_recorder, + popmember_type, ) return options diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 4cf2ffb9e..6f83f89b0 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -181,6 +181,7 @@ struct Options{ E<:AbstractExpression, EO<:NamedTuple, MW<:AbstractMutationWeights, + PM, _turbo, _bumper, _return_state, @@ -254,6 +255,7 @@ struct Options{ deterministic::Bool define_helper_functions::Bool use_recorder::Bool + popmember_type::Type{PM} end function Base.print(io::IO, @nospecialize(options::Options)) diff --git a/src/PopMember.jl b/src/PopMember.jl index bd195a6c2..1f23ff7e8 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -7,8 +7,25 @@ import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost +""" + AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + +Abstract type for population members. Defines the interface that all population members must implement. + +# Required fields (accessed via getproperty/setproperty!) +- `tree::N`: The expression tree +- `cost::L`: The cost including complexity penalty and normalization +- `loss::L`: The raw loss value +- `birth::Int`: Birth order/generation number +- `ref::Int`: Unique reference ID +- `parent::Int`: Parent reference ID +- `complexity::Int`: Cached complexity (accessed via getfield/setfield! for special handling) +""" +abstract type AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} end + # Define a member of population by equation, cost, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} <: + AbstractPopMember{T,L,N} tree::N cost::L # Inludes complexity penalty, normalization loss::L # Raw loss @@ -19,7 +36,9 @@ mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} ref::Int parent::Int end -@inline function Base.setproperty!(member::PopMember, field::Symbol, value) + +# Generic interface implementations for AbstractPopMember +@inline function Base.setproperty!(member::AbstractPopMember, field::Symbol, value) if field == :complexity throw( error("Don't set `.complexity` directly. Use `recompute_complexity!` instead.") @@ -34,7 +53,7 @@ end end return setfield!(member, field, value) end -@unstable @inline function Base.getproperty(member::PopMember, field::Symbol) +@unstable @inline function Base.getproperty(member::AbstractPopMember, field::Symbol) if field == :complexity throw( error("Don't access `.complexity` directly. Use `compute_complexity` instead.") @@ -145,7 +164,7 @@ function PopMember( ) end -function Base.copy(p::P) where {P<:PopMember} +function Base.copy(p::P) where {P<:AbstractPopMember} tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -156,14 +175,14 @@ function Base.copy(p::P) where {P<:PopMember} return P(tree, cost, loss, birth, complexity, ref, parent) end -function reset_birth!(p::PopMember; deterministic::Bool) +function reset_birth!(p::AbstractPopMember; deterministic::Bool) p.birth = get_birth_order(; deterministic) return p end # Can read off complexity directly from pop members function compute_complexity( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = getfield(member, :complexity) complexity == -1 && return recompute_complexity!(member, options; break_sharing) @@ -171,11 +190,18 @@ function compute_complexity( return complexity end function recompute_complexity!( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = compute_complexity(member.tree, options; break_sharing) setfield!(member, :complexity, complexity) return complexity end +# Function to extract PopMember type from Population or HallOfFame types +function popmember_type end + +# Default PopMember type for Options +import ..CoreModule.OptionsModule: default_popmember_type +default_popmember_type() = PopMember + end diff --git a/src/Population.jl b/src/Population.jl index 739ca828e..d3bd2b517 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -3,25 +3,29 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, string_tree +using ConstructionBase: constructorof using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..UtilsModule: bottomk_fast, argmin_fast, PerTaskCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct Population{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} n::Int end """ - Population(pop::Array{PopMember{T,L}, 1}) + Population(pop::Array{<:AbstractPopMember, 1}) Create population from list of PopMembers. """ -function Population(pop::Vector{<:PopMember}) +function Population(pop::Vector{<:AbstractPopMember}) return Population(pop, size(pop, 1)) end @@ -41,23 +45,34 @@ function Population( npop=nothing, ) where {T,L} @assert (population_size !== nothing) ⊻ (npop !== nothing) - population_size = if npop === nothing - population_size - else - npop - end - return Population( - [ - PopMember( + population_size = something(population_size, npop) + PM = options.popmember_type + + # Create first member to get concrete type + first_member = constructorof(PM)( + dataset, + gen_random_tree(nlength, options, nfeatures, T), + options; + parent=-1, + deterministic=options.deterministic, + ) + + # Use the concrete type for the array + members = typeof(first_member)[ + if i == 1 + first_member + else + constructorof(PM)( dataset, gen_random_tree(nlength, options, nfeatures, T), options; parent=-1, deterministic=options.deterministic, - ) for _ in 1:population_size - ], - population_size, - ) + ) + end for i in 1:population_size + ] + + return Population(members, population_size) end """ Population(X::AbstractMatrix{T}, y::AbstractVector{T}; @@ -90,8 +105,8 @@ Create random population and score them on the dataset. ) end -function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} - copied_members = Vector{PopMember{T,L,N}}(undef, pop.n) +function Base.copy(pop::P)::P where {T,L,N,PM,P<:Population{T,L,N,PM}} + copied_members = Vector{PM}(undef, pop.n) Threads.@threads for i in 1:(pop.n) copied_members[i] = copy(pop.members[i]) end @@ -118,7 +133,7 @@ function _best_of_sample( members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N}} p = options.tournament_selection_p n = length(members) # == tournament_selection_n adjusted_costs = Vector{L}(undef, n) @@ -218,4 +233,7 @@ function record_population(pop::Population, options::AbstractOptions)::RecordTyp ) end +# Type accessor for Population +popmember_type(::Type{<:Population{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index de709ddde..88c791449 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -2,6 +2,7 @@ module SymbolicRegression # Types export Population, + AbstractPopMember, PopMember, HallOfFame, Options, @@ -297,7 +298,7 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: PopMember, reset_birth! +using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -810,10 +811,14 @@ function _preserve_loaded_state!( options::AbstractOptions, ) where {T,L,N} nout = length(state.worker_output) + # Get the prototype to extract types + prototype_pop = state.last_pops[1][1] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + for j in 1:nout, i in 1:(options.populations) - (pop, _, _, _) = extract_from_worker( - state.worker_output[j][i], Population{T,L,N}, HallOfFame{T,L,N} - ) + (pop, _, _, _) = extract_from_worker(state.worker_output[j][i], PopType, HallType) state.last_pops[j][i] = copy(pop) end return nothing @@ -843,11 +848,16 @@ function _warmup_search!( # Multi-threaded doesn't like to fetch within a new task: c_rss = deepcopy(running_search_statistics) last_pop = state.worker_output[j][i] + + # Get the prototype to extract types + prototype_pop = state.last_pops[j][i] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + updated_pop = @sr_spawner( begin - in_pop = first( - extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) - ) + in_pop = first(extract_from_worker(last_pop, PopType, HallType)) _dispatch_s_r_cycle( in_pop, dataset, From d37dc4505ef6d28e815f64973ba41737ce4dd256 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 01:37:21 +0100 Subject: [PATCH 02/13] refactor: no Base.copy for generic AbstractPopMember --- src/PopMember.jl | 4 ++-- src/SymbolicRegression.jl | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index 1f23ff7e8..7c1092dce 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -164,7 +164,7 @@ function PopMember( ) end -function Base.copy(p::P) where {P<:AbstractPopMember} +function Base.copy(p::PopMember) tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -172,7 +172,7 @@ function Base.copy(p::P) where {P<:AbstractPopMember} complexity = copy(getfield(p, :complexity)) ref = copy(p.ref) parent = copy(p.parent) - return P(tree, cost, loss, birth, complexity, ref, parent) + return PopMember(tree, cost, loss, birth, complexity, ref, parent) end function reset_birth!(p::AbstractPopMember; deterministic::Bool) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 88c791449..acef3e7fb 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -2,7 +2,6 @@ module SymbolicRegression # Types export Population, - AbstractPopMember, PopMember, HallOfFame, Options, From b3cc71c258a44ac88e117906897aa5fc8e6ea5a9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 02:08:49 +0100 Subject: [PATCH 03/13] refactor: better interface for creating new member --- src/Mutate.jl | 93 +++++++++++++++++++---------------------------- src/PopMember.jl | 71 ++++++++++++++++++++++++++++++++++++ src/Population.jl | 3 +- 3 files changed, 109 insertions(+), 58 deletions(-) diff --git a/src/Mutate.jl b/src/Mutate.jl index 996f474a1..03772f9ca 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: AbstractPopMember, PopMember +using ..PopMemberModule: AbstractPopMember, PopMember, create_child using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -254,14 +254,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -278,14 +277,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -322,14 +320,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -340,19 +337,16 @@ end tmp_recorder["reason"] = "pass" end mutation_accepted = true - return ( - PopMember( - tree, - after_cost, - after_loss, - options, - newSize; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, + new_member = create_child( + member, + tree, + after_cost, + after_loss, + options; + complexity=newSize, + parent_ref=parent_ref, ) + return (new_member, mutation_accepted, num_evals) end end @@ -583,17 +577,10 @@ function mutate!( simplify_tree!(tree, options.operators) tree = combine_operators(tree, options.operators) @recorder recorder["type"] = "simplify" - return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options; - parent=parent_ref, - deterministic=options.deterministic, - ), - return_immediately=true, + new_member = create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ) + return MutationResult{N,P}(; member=new_member, return_immediately=true) end function mutate!( @@ -645,14 +632,8 @@ function mutate!( recorder["reason"] = "identity" end return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options, - compute_complexity(tree, options); - parent=parent_ref, - deterministic=options.deterministic, + member=create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ), return_immediately=true, ) @@ -705,23 +686,23 @@ function crossover_generation( ) num_evals += 2 * dataset_fraction(dataset) - baby1 = PopMember( + baby1 = create_child( + (member1, member2), child_tree1::AbstractExpression, after_cost1, after_loss1, - options, - afterSize1; - parent=member1.ref, - deterministic=options.deterministic, + options; + complexity=afterSize1, + parent_ref=member1.ref, )::P - baby2 = PopMember( + baby2 = create_child( + (member1, member2), child_tree2::AbstractExpression, after_cost2, after_loss2, - options, - afterSize2; - parent=member2.ref, - deterministic=options.deterministic, + options; + complexity=afterSize2, + parent_ref=member2.ref, )::P @recorder begin diff --git a/src/PopMember.jl b/src/PopMember.jl index 7c1092dce..b01d6dfb9 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,6 +2,7 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree +import DynamicExpressions: constructorof using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order @@ -197,6 +198,74 @@ function recompute_complexity!( return complexity end +# Interface for creating child members with custom field preservation +""" + create_child(parent::P, tree, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember + +Create a new PopMember derived from a parent (mutation case). +Custom types should override to preserve their additional fields. + +# Arguments +- `parent`: The parent member being mutated +- `tree`: The new expression tree +- `cost`: The new cost +- `loss`: The new loss +- `options`: The options +- `complexity`: The complexity (computed if not provided) +- `parent_ref`: Reference to parent for tracking +""" +function create_child( + parent::P, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., +) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} + actual_complexity = @something complexity compute_complexity(tree, options) + return constructorof(P)( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + +""" + create_child(parents::Tuple{P,P}, tree, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember + +Create a new PopMember from two parents (crossover case). +Custom types should override to blend their additional fields. +""" +function create_child( + parents::Tuple{P,P}, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., +) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} + actual_complexity = @something complexity compute_complexity(tree, options) + return constructorof(P)( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + # Function to extract PopMember type from Population or HallOfFame types function popmember_type end @@ -204,4 +273,6 @@ function popmember_type end import ..CoreModule.OptionsModule: default_popmember_type default_popmember_type() = PopMember +constructorof(::Type{<:PopMember}) = PopMember + end diff --git a/src/Population.jl b/src/Population.jl index d3bd2b517..ad24ead3d 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -2,8 +2,7 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpression, string_tree -using ConstructionBase: constructorof +using DynamicExpressions: AbstractExpression, string_tree, constructorof using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! From 6d4d64f82e9a649d6db2359d9f094d1a24da053e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 02:10:24 +0100 Subject: [PATCH 04/13] refactor: force custom implementations of `create_child` --- src/PopMember.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index b01d6dfb9..eb3e11820 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -224,7 +224,7 @@ function create_child( complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} +) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} actual_complexity = @something complexity compute_complexity(tree, options) return constructorof(P)( tree, @@ -253,7 +253,7 @@ function create_child( complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} +) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} actual_complexity = @something complexity compute_complexity(tree, options) return constructorof(P)( tree, From bfcf1f4e9c023b2b2963583a6fc795a8a37e4005 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 15:50:14 +0100 Subject: [PATCH 05/13] refactor: add missing parts of AbstractPopMember interface --- src/ExpressionBuilder.jl | 18 ++-- src/HallOfFame.jl | 37 +++++-- src/Mutate.jl | 40 +++---- src/ParametricExpression.jl | 4 +- src/PopMember.jl | 30 ++---- src/SearchUtils.jl | 19 ++-- src/SymbolicRegression.jl | 20 +++- src/TemplateExpression.jl | 6 +- test/runtests.jl | 1 + test/test_abstract_popmember.jl | 186 ++++++++++++++++++++++++++++++++ 10 files changed, 286 insertions(+), 75 deletions(-) create mode 100644 test/test_abstract_popmember.jl diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index db6f5e82b..0867ab47a 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -11,7 +11,8 @@ using DynamicExpressions: using ..CoreModule: AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember, create_child +using ..ComplexityModule: compute_complexity import DynamicExpressions: get_operators import ..CoreModule: create_expression @@ -107,15 +108,16 @@ end return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) end function embed_metadata( - member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L} - return PopMember( + member::PM, options::AbstractOptions, dataset::Dataset{T,L} + ) where {T,L,N,PM<:AbstractPopMember{T,L,N}} + return create_child( + member, embed_metadata(member.tree, options, dataset), member.cost, member.loss, - nothing; - member.ref, - member.parent, + options; + complexity=compute_complexity(member, options), + parent_ref=member.ref, deterministic=options.deterministic, ) end @@ -135,7 +137,7 @@ end end function embed_metadata( vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + ) where {T,L,H<:Union{HallOfFame,Population,AbstractPopMember}} return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec) end end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 4a1b7841e..d18f7fd36 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -73,17 +73,36 @@ function HallOfFame( options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) + PM = options.popmember_type - return HallOfFame{T,L,typeof(base_tree),PopMember{T,L,typeof(base_tree)}}( + # Create a prototype member to get the concrete type + prototype = PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + + PMtype = typeof(prototype) + + return HallOfFame{T,L,typeof(base_tree),PMtype}( [ - PopMember( - copy(base_tree), - L(0), - L(Inf), - options; - parent=-1, - deterministic=options.deterministic, - ) for i in 1:(options.maxsize) + if i == 1 + prototype + else + PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + end for i in 1:(options.maxsize) ], [false for i in 1:(options.maxsize)], ) diff --git a/src/Mutate.jl b/src/Mutate.jl index 03772f9ca..412d7ea06 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -40,10 +40,10 @@ using ..MutationFunctionsModule: using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder -abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end +abstract type AbstractMutationResult{N<:AbstractExpression,P<:AbstractPopMember} end """ - MutationResult{N<:AbstractExpression,P<:PopMember} + MutationResult{N<:AbstractExpression,P<:AbstractPopMember} Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions. @@ -160,7 +160,7 @@ Use this to modify how `mutate_constant` changes for an expression type. function condition_mutate_constant!( ::Type{<:AbstractExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) @@ -352,7 +352,7 @@ end @generated function _dispatch_mutations!( tree::AbstractExpression, - member::PopMember, + member::AbstractPopMember, mutation_choice::Symbol, weights::W, options::AbstractOptions; @@ -381,7 +381,7 @@ end mutation_weights::AbstractMutationWeights, options::AbstractOptions; kws..., - ) where {N<:AbstractExpression,P<:PopMember,S} + ) where {N<:AbstractExpression,P<:AbstractPopMember,S} Perform a mutation on the given `tree` and `member` using the specified mutation type `S`. Various `kws` are provided to access other data needed for some mutations. @@ -409,7 +409,7 @@ so it can always return immediately. """ function mutate!( ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws... -) where {N<:AbstractExpression,P<:PopMember,S} +) where {N<:AbstractExpression,P<:AbstractPopMember,S} return error("Unknown mutation choice: $S") end @@ -422,7 +422,7 @@ function mutate!( recorder::RecordType, temperature, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_constant(tree, temperature, options) @recorder recorder["type"] = "mutate_constant" return MutationResult{N,P}(; tree=tree) @@ -436,7 +436,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_operator(tree, options) @recorder recorder["type"] = "mutate_operator" return MutationResult{N,P}(; tree=tree) @@ -451,7 +451,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_feature(tree, nfeatures) @recorder recorder["type"] = "mutate_feature" return MutationResult{N,P}(; tree=tree) @@ -465,7 +465,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = swap_operands(tree) @recorder recorder["type"] = "swap_operands" return MutationResult{N,P}(; tree=tree) @@ -480,7 +480,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} if rand() < 0.5 tree = append_random_op(tree, options, nfeatures) @recorder recorder["type"] = "add_node:append" @@ -500,7 +500,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = insert_random_op(tree, options, nfeatures) @recorder recorder["type"] = "insert_node" return MutationResult{N,P}(; tree=tree) @@ -514,7 +514,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = delete_random_op!(tree) @recorder recorder["type"] = "delete_node" return MutationResult{N,P}(; tree=tree) @@ -528,7 +528,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = form_random_connection!(tree) @recorder recorder["type"] = "form_connection" return MutationResult{N,P}(; tree=tree) @@ -542,7 +542,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = break_random_connection!(tree) @recorder recorder["type"] = "break_connection" return MutationResult{N,P}(; tree=tree) @@ -556,7 +556,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = randomly_rotate_tree!(tree) @recorder recorder["type"] = "rotate_tree" return MutationResult{N,P}(; tree=tree) @@ -572,7 +572,7 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @assert options.should_simplify simplify_tree!(tree, options.operators) tree = combine_operators(tree, options.operators) @@ -593,7 +593,7 @@ function mutate!( curmaxsize, nfeatures, kws..., -) where {T,N<:AbstractExpression{T},P<:PopMember} +) where {T,N<:AbstractExpression{T},P<:AbstractPopMember} tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) @@ -608,7 +608,7 @@ function mutate!( recorder::RecordType, dataset::Dataset, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} cur_member, new_num_evals = optimize_constants(dataset, member, options) @recorder recorder["type"] = "optimize" return MutationResult{N,P}(; @@ -625,7 +625,7 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @recorder begin recorder["type"] = "identity" recorder["result"] = "accept" diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 2717afbdc..b7c9ab2f6 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -24,7 +24,7 @@ using ..CoreModule: AbstractExpressionSpec, get_indices, ExpressionSpecModule as ES -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB @@ -102,7 +102,7 @@ end function MM.condition_mutate_constant!( ::Type{<:ParametricExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/src/PopMember.jl b/src/PopMember.jl index eb3e11820..698fa9627 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -198,35 +198,25 @@ function recompute_complexity!( return complexity end -# Interface for creating child members with custom field preservation """ - create_child(parent::P, tree, cost, loss, options; - complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember - -Create a new PopMember derived from a parent (mutation case). -Custom types should override to preserve their additional fields. + create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where {T,L,P<:PopMember{T,L}} -# Arguments -- `parent`: The parent member being mutated -- `tree`: The new expression tree -- `cost`: The new cost -- `loss`: The new loss -- `options`: The options -- `complexity`: The complexity (computed if not provided) -- `parent_ref`: Reference to parent for tracking +Create a new PopMember with a potentially different expression type. +Used by embed_metadata where the expression gains metadata. """ function create_child( parent::P, - tree::N, + tree::AbstractExpression{T}, cost::L, loss::L, options; complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} +) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) - return constructorof(P)( + return PopMember( tree, cost, loss, @@ -246,16 +236,16 @@ Custom types should override to blend their additional fields. """ function create_child( parents::Tuple{P,P}, - tree::N, + tree::AbstractExpression{T}, cost::L, loss::L, options; complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs..., -) where {T,L,N<:AbstractExpression{T},P<:PopMember{T,L,N}} +) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) - return constructorof(P)( + return PopMember( tree, cost, loss, diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 89eaa8cbf..014b036c4 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -17,7 +17,7 @@ using ..UtilsModule: subscriptify using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -583,8 +583,9 @@ The state of the search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. """ -Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} <: - AbstractSearchState{T,L,N} +Base.@kwdef struct SearchState{ + T,L,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N},WorkerOutputType,ChannelType +} <: AbstractSearchState{T,L,N} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -592,16 +593,16 @@ Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,Cha channels::Vector{Vector{ChannelType}} worker_assignment::WorkerAssignments task_order::Vector{Tuple{Int,Int}} - halls_of_fame::Vector{HallOfFame{T,L,N}} - last_pops::Vector{Vector{Population{T,L,N}}} - best_sub_pops::Vector{Vector{Population{T,L,N}}} + halls_of_fame::Vector{HallOfFame{T,L,N,PM}} + last_pops::Vector{Vector{Population{T,L,N,PM}}} + best_sub_pops::Vector{Vector{Population{T,L,N,PM}}} all_running_search_statistics::Vector{RunningSearchStatistics} num_evals::Vector{Vector{Float64}} cycles_remaining::Vector{Int} cur_maxsizes::Vector{Int} stdin_reader::StdinReader record::Base.RefValue{RecordType} - seed_members::Vector{Vector{PopMember{T,L,N}}} + seed_members::Vector{Vector{PM}} end function save_to_file( @@ -718,7 +719,7 @@ end function update_hall_of_fame!( hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions -) where {PM<:PopMember} +) where {PM<:AbstractPopMember} for member in members size = compute_complexity(member, options) valid_size = 0 < size <= options.maxsize @@ -793,7 +794,7 @@ function parse_guesses( guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L},D<:Dataset{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] guess_lists = _make_vector_vector(guesses, nout) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index acef3e7fb..ef7c86f0a 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -298,6 +298,7 @@ using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type +using .CoreModule.UtilsModule: get_birth_order using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -633,8 +634,19 @@ end example_dataset = first(datasets) example_ex = create_expression(init_value(T), options, example_dataset) NT = typeof(example_ex) - PopType = Population{T,L,NT} - HallOfFameType = HallOfFame{T,L,NT} + # Create a prototype member to get the concrete type + prototype_member = options.popmember_type( + copy(example_ex), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + PMType = typeof(prototype_member) + PopType = Population{T,L,NT,PMType} + HallOfFameType = HallOfFame{T,L,NT,PMType} WorkerOutputType = get_worker_output_type( Val(ropt.parallelism), PopType, HallOfFameType ) @@ -692,9 +704,9 @@ end j in 1:nout ] - seed_members = [PopMember{T,L,NT}[] for j in 1:nout] + seed_members = [Vector{PMType}() for j in 1:nout] - return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; + return SearchState{T,L,NT,PMType,WorkerOutputType,ChannelType}(; procs=procs, we_created_procs=we_created_procs, worker_output=worker_output, diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index af86f4825..0f87cb9b0 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -52,7 +52,7 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..ComposableExpressionModule: ComposableExpression, ValidVector struct ParamVector{T} <: AbstractVector{T} @@ -745,7 +745,7 @@ function MM.condition_mutation_weights!( @nospecialize(options::AbstractOptions), curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:TemplateExpression,P<:AbstractPopMember{T,L,N}} if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 @@ -828,7 +828,7 @@ end function MM.condition_mutate_constant!( ::Type{<:TemplateExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/test/runtests.jl b/test/runtests.jl index 7aef02fea..f4bd81111 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -164,6 +164,7 @@ end end include("test_abstract_numbers.jl") +include("test_abstract_popmember.jl") include("test_logging.jl") include("test_pretty_printing.jl") diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl new file mode 100644 index 000000000..318d260a9 --- /dev/null +++ b/test/test_abstract_popmember.jl @@ -0,0 +1,186 @@ +@testitem "Custom AbstractPopMember implementation" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + + import SymbolicRegression.PopMemberModule: create_child + + # Define a custom PopMember that tracks generation count + mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + tree::N + cost::L + loss::L + birth::Int + complexity::Int + ref::Int + parent::Int + generation::Int # Custom field to track generation + end + + # # Direct constructor that matches field order + function CustomPopMember( + tree::N, + cost::L, + loss::L, + birth::Int, + complexity::Int, + ref::Int, + parent::Int, + generation::Int, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember{T,L,N}( + tree, cost, loss, birth, complexity, ref, parent, generation + ) + end + + function CustomPopMember( + tree::N, + cost::L, + loss::L, + options, + complexity::Int; + parent=-1, + deterministic=nothing, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + # Constructor for Population initialization (dataset, tree, options) + function CustomPopMember( + dataset::SymbolicRegression.Dataset, tree, options; parent=-1, deterministic=nothing + ) + ex = SymbolicRegression.create_expression(tree, options, dataset) + complexity = SymbolicRegression.compute_complexity(ex, options) + cost, loss = SymbolicRegression.eval_cost( + dataset, ex, options; complexity=complexity + ) + + return CustomPopMember( + ex, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + + # Define copy for CustomPopMember + function Base.copy(p::CustomPopMember) + return CustomPopMember( + copy(p.tree), + copy(p.cost), + copy(p.loss), + copy(p.birth), + copy(getfield(p, :complexity)), + copy(p.ref), + copy(p.parent), + copy(p.generation), + ) + end + + function create_child( + parent::CustomPopMember{T,L}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., + ) where {T,L} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + actual_complexity, + abs(rand(Int)), + parent_ref, + parent.generation + 1, + ) + end + + function create_child( + parents::Tuple{<:CustomPopMember,<:CustomPopMember}, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + kwargs..., + ) where {T,L,N<:AbstractExpression{T}} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + max_generation = max(parents[1].generation, parents[2].generation) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.CoreModule.UtilsModule.get_birth_order(; + deterministic=options.deterministic + ), + actual_complexity, + abs(rand(Int)), + parent_ref, + max_generation + 1, + ) + end + + # Test that we can run equation_search with CustomPopMember + X = randn(Float32, 2, 100) + y = @. X[1, :]^2 - X[2, :] + + options = SymbolicRegression.Options(; + binary_operators=[+, -], + populations=1, + population_size=20, + maxsize=5, + popmember_type=CustomPopMember, + deterministic=true, + seed=0, + ) + + # Test that options were created with correct type + @test options.popmember_type == CustomPopMember + + hall_of_fame = equation_search( + X, y; options=options, niterations=2, parallelism=:serial + ) + + # Verify that we got results + @test sum(hall_of_fame.exists) > 0 + + # Verify that the members are CustomPopMember + for i in eachindex(hall_of_fame.members, hall_of_fame.exists) + if hall_of_fame.exists[i] + @test hall_of_fame.members[i] isa CustomPopMember + # Check that generation field exists + @test hall_of_fame.members[i].generation >= 0 + end + end + + # Verify we can extract the best member + best_idx = findlast(hall_of_fame.exists) + @test !isnothing(best_idx) + best_member = hall_of_fame.members[best_idx] + @test best_member isa CustomPopMember +end From 690e0f1ef37e1fda545a15bb063e1e1bbf07f6b8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 15:53:40 +0100 Subject: [PATCH 06/13] refactor: move forward forward decl --- src/Options.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Options.jl b/src/Options.jl index c746b1332..977c3fc96 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -40,9 +40,6 @@ using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutatio import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization using ..UtilsModule: @save_kwargs, @ignore - -# Forward declaration - will be defined in PopMemberModule -function default_popmember_type end using ..ExpressionSpecModule: AbstractExpressionSpec, ExpressionSpec, @@ -227,6 +224,8 @@ recommend_loss_function_expression(expression_type) = false create_mutation_weights(w::AbstractMutationWeights) = w create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) +function default_popmember_type end + @unstable function with_max_degree_from_context( node_type, user_provided_operators, operators ) From 78b6b80a2c5536cff8c9aa588498dad7e48f7e87 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 16:15:58 +0100 Subject: [PATCH 07/13] refactor: move imports to top --- src/PopMember.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index 698fa9627..050e2344f 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -4,6 +4,7 @@ using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree import DynamicExpressions: constructorof using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression +import ..CoreModule.OptionsModule: default_popmember_type import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost @@ -259,10 +260,7 @@ end # Function to extract PopMember type from Population or HallOfFame types function popmember_type end -# Default PopMember type for Options -import ..CoreModule.OptionsModule: default_popmember_type default_popmember_type() = PopMember - constructorof(::Type{<:PopMember}) = PopMember end From fd3f5a5f82b24ffded01f89b16e2ce26242c1a06 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Sep 2025 18:16:52 +0100 Subject: [PATCH 08/13] fix: mark unstable --- src/PopMember.jl | 4 ++-- test/test_abstract_popmember.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index 050e2344f..7e5ca84e9 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -260,7 +260,7 @@ end # Function to extract PopMember type from Population or HallOfFame types function popmember_type end -default_popmember_type() = PopMember -constructorof(::Type{<:PopMember}) = PopMember +@unstable default_popmember_type() = PopMember +@unstable constructorof(::Type{<:PopMember}) = PopMember end diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index 318d260a9..7e6aff8b5 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -2,6 +2,7 @@ using SymbolicRegression using DynamicExpressions using Test + using DispatchDoctor: @unstable import SymbolicRegression.PopMemberModule: create_child @@ -76,7 +77,7 @@ ) end - DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + @unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember # Define copy for CustomPopMember function Base.copy(p::CustomPopMember) From 26665da252a2c87a6e2dee01f5e2f856a7105ce7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 12:16:36 +0100 Subject: [PATCH 09/13] fix: `parse_guesses` for custom AbstractPopMember --- src/ExpressionBuilder.jl | 1 - src/MLJInterface.jl | 4 ++-- src/PopMember.jl | 12 ++++++++---- src/SearchUtils.jl | 31 +++++++++++++++++++++++++++---- test/test_abstract_popmember.jl | 2 -- 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 33b1cc09e..99de5bdc6 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -118,7 +118,6 @@ end options; complexity=compute_complexity(member, options), parent_ref=member.ref, - deterministic=options.deterministic, ) end function embed_metadata( diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 4ca68ed65..ccc6c33ee 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -39,8 +39,8 @@ using ..CoreModule: ExpressionSpec, get_expression_type, check_warm_start_compatibility -using ..CoreModule.OptionsModule: - DEFAULT_OPTIONS, OPTION_DESCRIPTIONS, default_popmember_type +using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using ..PopMemberModule: default_popmember_type using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore diff --git a/src/PopMember.jl b/src/PopMember.jl index 7e5ca84e9..ec350686e 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -201,7 +201,7 @@ end """ create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options; - complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where {T,L,P<:PopMember{T,L}} + complexity::Union{Int,Nothing}=nothing, parent_ref) where {T,L,P<:PopMember{T,L}} Create a new PopMember with a potentially different expression type. Used by embed_metadata where the expression gains metadata. @@ -214,7 +214,6 @@ function create_child( options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) return PopMember( @@ -230,7 +229,7 @@ end """ create_child(parents::Tuple{P,P}, tree, cost, loss, options; - complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember + complexity::Union{Int,Nothing}=nothing, parent_ref) where P<:AbstractPopMember Create a new PopMember from two parents (crossover case). Custom types should override to blend their additional fields. @@ -243,7 +242,6 @@ function create_child( options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L,P<:PopMember{T,L}} actual_complexity = @something complexity compute_complexity(tree, options) return PopMember( @@ -263,4 +261,10 @@ function popmember_type end @unstable default_popmember_type() = PopMember @unstable constructorof(::Type{<:PopMember}) = PopMember +@inline function with_expression_type( + ::Type{<:PopMember{T,L}}, ::Type{N} +) where {T,L,N<:AbstractExpression{T}} + return PopMember{T,L,N} +end + end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ae248951a..1c71d3723 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -12,12 +12,18 @@ using DispatchDoctor: @unstable using Logging: AbstractLogger using DynamicExpressions: - AbstractExpression, string_tree, parse_expression, EvalOptions, with_type_parameters + AbstractExpression, + string_tree, + parse_expression, + EvalOptions, + with_type_parameters, + constructorof using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features +using ..CoreModule: + Dataset, AbstractOptions, Options, RecordType, max_features, create_expression using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember, AbstractPopMember +using ..PopMemberModule: PopMember, AbstractPopMember, with_expression_type using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -800,7 +806,9 @@ function parse_guesses( dataset = datasets[j] for g in guess_lists[j] ex = _parse_guess_expression(T, g, dataset, options) - member = PopMember(dataset, ex, options; deterministic=options.deterministic) + member = constructorof(P)( + dataset, ex, options; deterministic=options.deterministic + ) if options.should_optimize_constants member, _ = optimize_constants(dataset, member, options) end @@ -818,6 +826,21 @@ function parse_guesses( end return out end + +# Deal with non-concrete PopMember types +function parse_guesses( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N === Any && error("Failed to infer expression type") + ConcreteP = with_expression_type(P, N) + return parse_guesses(ConcreteP, guesses, datasets, options) +end + function _make_vector_vector(guesses, nout) if nout == 1 if guesses isa AbstractVector{<:AbstractVector} diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index 7e6aff8b5..058cbc02c 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -101,7 +101,6 @@ options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L} actual_complexity = @something complexity SymbolicRegression.compute_complexity( tree, options @@ -126,7 +125,6 @@ options; complexity::Union{Int,Nothing}=nothing, parent_ref, - kwargs..., ) where {T,L,N<:AbstractExpression{T}} actual_complexity = @something complexity SymbolicRegression.compute_complexity( tree, options From f1053282ea45125c3dce31d998fd8e586f9098e1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 12:26:02 +0100 Subject: [PATCH 10/13] fix: mark unstable to avoid recursion --- src/SearchUtils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 1c71d3723..461f0a1da 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -793,7 +793,7 @@ end """Parse user-provided guess expressions and convert them into optimized `PopMember` objects for each output dataset.""" -function parse_guesses( +@unstable function parse_guesses( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, @@ -828,7 +828,7 @@ function parse_guesses( end # Deal with non-concrete PopMember types -function parse_guesses( +@unstable function parse_guesses( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, From f45f0a1115a570ea3295402f4692e1453610d138 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 12:51:37 +0100 Subject: [PATCH 11/13] fix: allow `_get_cost` to be generic --- src/Population.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Population.jl b/src/Population.jl index ad24ead3d..00f603258 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -171,7 +171,7 @@ function _best_of_sample( end return members[chosen_idx] end -_get_cost(member::PopMember) = member.cost +_get_cost(member::AbstractPopMember) = member.cost const CACHED_WEIGHTS = let init_k = collect(0:5), From 2dc06456e3b1ecf0832d1cc67597eb3aed9120f8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 13:32:16 +0100 Subject: [PATCH 12/13] fix: try to avoid recursive type inference --- src/SearchUtils.jl | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 461f0a1da..ced0707c5 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -798,6 +798,29 @@ end guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, +) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} + return _parse_guesses_impl(P, guesses, datasets, options) +end + +# Deal with non-concrete PopMember types +@unstable function parse_guesses( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N in (Any, Union{}) && error("Failed to infer expression type") + ConcreteP = with_expression_type(P, N) + return _parse_guesses_impl(ConcreteP, guesses, datasets, options) +end + +@inline function _parse_guesses_impl( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, ) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] @@ -827,20 +850,6 @@ end return out end -# Deal with non-concrete PopMember types -@unstable function parse_guesses( - ::Type{P}, - guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, - datasets::Vector{D}, - options::AbstractOptions, -) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} - NodeType = with_type_parameters(options.node_type, T) - N = Base.promote_op(create_expression, NodeType, typeof(options), D) - N === Any && error("Failed to infer expression type") - ConcreteP = with_expression_type(P, N) - return parse_guesses(ConcreteP, guesses, datasets, options) -end - function _make_vector_vector(guesses, nout) if nout == 1 if guesses isa AbstractVector{<:AbstractVector} From 4d3e1cf371717facfc92fbbeb2e6a00a42d50be1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 6 Oct 2025 19:19:33 +0100 Subject: [PATCH 13/13] refactor: `infer_popmember_type` --- src/PopMember.jl | 12 +++++++++++- src/SearchUtils.jl | 26 +++++++++++--------------- src/SymbolicRegression.jl | 22 ++++++---------------- test/test_abstract_popmember.jl | 7 +++++++ 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/PopMember.jl b/src/PopMember.jl index ec350686e..71f8707de 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,7 +2,7 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree -import DynamicExpressions: constructorof +import DynamicExpressions: constructorof, with_type_parameters using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..CoreModule.OptionsModule: default_popmember_type import ..ComplexityModule: compute_complexity @@ -267,4 +267,14 @@ function popmember_type end return PopMember{T,L,N} end +@inline function with_type_parameters( + ::Type{<:PopMember}, ::Type{T}, ::Type{L}, ::Type{N} +) where {T,L,N} + return PopMember{T,L,N} +end + +@inline function expression_type(::Type{<:AbstractPopMember{<:Any,<:Any,N}}) where {N} + return N +end + end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ced0707c5..05efe4c09 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -23,7 +23,7 @@ using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features, create_expression using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember, AbstractPopMember, with_expression_type +using ..PopMemberModule: PopMember, AbstractPopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -34,6 +34,15 @@ using ..CheckConstraintsModule: check_constraints function logging_callback! end +@unstable @inline function infer_popmember_type( + ::Type{T}, ::Type{L}, ::Type{D}, options +) where {T,L,D<:Dataset} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N in (Any, Union{}) && error("Failed to infer expression type") + return with_type_parameters(options.popmember_type, T, L, N) +end + """ @filtered_async expr @@ -793,26 +802,13 @@ end """Parse user-provided guess expressions and convert them into optimized `PopMember` objects for each output dataset.""" -@unstable function parse_guesses( - ::Type{P}, - guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, - datasets::Vector{D}, - options::AbstractOptions, -) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} - return _parse_guesses_impl(P, guesses, datasets, options) -end - -# Deal with non-concrete PopMember types @unstable function parse_guesses( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, ) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} - NodeType = with_type_parameters(options.node_type, T) - N = Base.promote_op(create_expression, NodeType, typeof(options), D) - N in (Any, Union{}) && error("Failed to infer expression type") - ConcreteP = with_expression_type(P, N) + ConcreteP = infer_popmember_type(T, L, D, options) return _parse_guesses_impl(ConcreteP, guesses, datasets, options) end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 40119fc70..4e9443f23 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -297,7 +297,8 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type +using .PopMemberModule: + AbstractPopMember, PopMember, reset_birth!, popmember_type, expression_type using .CoreModule.UtilsModule: get_birth_order using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: @@ -339,7 +340,8 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame!, parse_guesses, - logging_callback! + logging_callback!, + infer_popmember_type using .LoggingModule: AbstractSRLogger, SRLogger, get_logger using .TemplateExpressionModule: TemplateExpression, TemplateStructure, TemplateExpressionSpec, ParamVector, has_params @@ -631,20 +633,8 @@ end @recorder record["options"] = "$(options)" nout = length(datasets) - example_dataset = first(datasets) - example_ex = create_expression(init_value(T), options, example_dataset) - NT = typeof(example_ex) - # Create a prototype member to get the concrete type - prototype_member = options.popmember_type( - copy(example_ex), - L(0), - L(Inf), - options, - 1; # complexity - parent=-1, - deterministic=options.deterministic, - ) - PMType = typeof(prototype_member) + PMType = infer_popmember_type(T, L, D, options) + NT = expression_type(PMType) PopType = Population{T,L,NT,PMType} HallOfFameType = HallOfFame{T,L,NT,PMType} WorkerOutputType = get_worker_output_type( diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl index 058cbc02c..eabc6886d 100644 --- a/test/test_abstract_popmember.jl +++ b/test/test_abstract_popmember.jl @@ -79,6 +79,13 @@ @unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + # Define with_type_parameters for CustomPopMember + @unstable function DynamicExpressions.with_type_parameters( + ::Type{<:CustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N} + ) where {T,L,N} + return CustomPopMember{T,L,N} + end + # Define copy for CustomPopMember function Base.copy(p::CustomPopMember) return CustomPopMember(