diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 585a6ef78..7fa7798d8 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -17,7 +17,7 @@ using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction using ..UtilsModule: get_birth_order, PerTaskCache, stable_get! using ..LossFunctionsModule: eval_loss, loss_to_cost -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember function can_optimize(::AbstractExpression{T}, options) where {T} return can_optimize(T, options) @@ -31,7 +31,7 @@ end member::P, options::AbstractOptions; rng::AbstractRNG=default_rng(), -)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,N,P<:AbstractPopMember{T,L,N}} can_optimize(member.tree, options) || return (member, 0.0) nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) @@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T,L,N,P<:AbstractPopMember{T,L,N}} tree = member.tree x0, refs = get_scalar_constants(tree) @assert count_constants_for_optimization(tree) == length(x0) @@ -76,7 +76,7 @@ function _optimize_constants( end function _optimize_constants_inner( f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {F,G,T,L,N,P<:AbstractPopMember{T,L,N}} obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing f else diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 4de4e56a2..99de5bdc6 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -11,7 +11,8 @@ using DynamicExpressions: using ..CoreModule: AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember, create_child +using ..ComplexityModule: compute_complexity import DynamicExpressions: get_operators import ..CoreModule: create_expression @@ -107,16 +108,16 @@ end return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) end function embed_metadata( - member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L} - return PopMember( + member::PM, options::AbstractOptions, dataset::Dataset{T,L} + ) where {T,L,N,PM<:AbstractPopMember{T,L,N}} + return create_child( + member, embed_metadata(member.tree, options, dataset), member.cost, member.loss, - nothing; - member.ref, - member.parent, - deterministic=options.deterministic, + options; + complexity=compute_complexity(member, options), + parent_ref=member.ref, ) end function embed_metadata( @@ -135,7 +136,7 @@ end end function embed_metadata( vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + ) where {T,L,H<:Union{HallOfFame,Population,AbstractPopMember}} return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec) end end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index c90fbe3a4..d18f7fd36 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -6,12 +6,13 @@ using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf """ - HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N}} List of the best members seen all time in `.members`, with `.members[c]` being the best member seen at complexity c. Including only the members which actually @@ -19,15 +20,19 @@ have been set, you can run `.members[exists]`. # Fields -- `members::Array{PopMember{T,L,N},1}`: List of the best members seen all time. +- `members::Array{PM,1}`: List of the best members seen all time. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. """ -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct HallOfFame{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} exists::Array{Bool,1} #Whether it has been set end -function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} +function Base.show( + io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N,PM} +) where {T,L,N,PM} println(io, "HallOfFame{...}:") for i in eachindex(hof.members, hof.exists) s_member, s_exists = if hof.exists[i] @@ -47,8 +52,8 @@ function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where end return nothing end -function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,HOF<:HallOfFame{T,L,N}} - return PopMember{T,L,N} +function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,PM,HOF<:HallOfFame{T,L,N,PM}} + return PM end """ @@ -68,17 +73,36 @@ function HallOfFame( options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) + PM = options.popmember_type - return HallOfFame{T,L,typeof(base_tree)}( + # Create a prototype member to get the concrete type + prototype = PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + + PMtype = typeof(prototype) + + return HallOfFame{T,L,typeof(base_tree),PMtype}( [ - PopMember( - copy(base_tree), - L(0), - L(Inf), - options; - parent=-1, - deterministic=options.deterministic, - ) for i in 1:(options.maxsize) + if i == 1 + prototype + else + PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + end for i in 1:(options.maxsize) ], [false for i in 1:(options.maxsize)], ) @@ -93,11 +117,10 @@ end """ calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,P}) where {T<:DATA_TYPE,L<:LOSS_TYPE} """ -function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} +function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N,PM}) where {T,L,N,PM} # TODO - remove dataset from args. - P = PopMember{T,L,N} # Dominating pareto curve - must be better than all simpler equations - dominating = P[] + dominating = PM[] for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue @@ -276,4 +299,7 @@ function format_hall_of_fame(hof::AbstractVector{<:HallOfFame}, options) end # TODO: Re-use this in `string_dominating_pareto_curve` +# Type accessor for HallOfFame +popmember_type(::Type{<:HallOfFame{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index bb2fddfb0..ccc6c33ee 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -40,6 +40,7 @@ using ..CoreModule: get_expression_type, check_warm_start_compatibility using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using ..PopMemberModule: default_popmember_type using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore diff --git a/src/Migration.jl b/src/Migration.jl index f7fe61b89..3988b81ea 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -2,7 +2,7 @@ module MigrationModule using ..CoreModule: AbstractOptions using ..PopulationModule: Population -using ..PopMemberModule: PopMember, reset_birth! +using ..PopMemberModule: AbstractPopMember, PopMember, reset_birth! using ..UtilsModule: poisson_sample """ @@ -14,7 +14,7 @@ to do so. The original migrant population is not modified. Pass with, e.g., """ function migrate!( migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat -) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} +) where {T,L,N,PM<:AbstractPopMember{T,L,N},P<:Population{T,L,N,PM}} base_pop = migration.second population_size = length(base_pop.members) mean_number_replaced = population_size * frac diff --git a/src/Mutate.jl b/src/Mutate.jl index f5fd88457..412d7ea06 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember, create_child using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -40,10 +40,10 @@ using ..MutationFunctionsModule: using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder -abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end +abstract type AbstractMutationResult{N<:AbstractExpression,P<:AbstractPopMember} end """ - MutationResult{N<:AbstractExpression,P<:PopMember} + MutationResult{N<:AbstractExpression,P<:AbstractPopMember} Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions. @@ -61,7 +61,8 @@ This struct encapsulates the result of a mutation operation. Either a new expres Return the `member` if you want to return immediately, and have computed the loss value as part of the mutation. """ -struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P} +struct MutationResult{N<:AbstractExpression,P<:AbstractPopMember} <: + AbstractMutationResult{N,P} tree::Union{N,Nothing} member::Union{P,Nothing} num_evals::Float64 @@ -73,7 +74,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes member::Union{_P,Nothing}=nothing, num_evals::Float64=0.0, return_immediately::Bool=false, - ) where {_N<:AbstractExpression,_P<:PopMember} + ) where {_N<:AbstractExpression,_P<:AbstractPopMember} @assert( (tree === nothing) ⊻ (member === nothing), "Mutation result must return either a tree or a pop member, not both" @@ -83,7 +84,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes end """ - condition_mutation_weights!(weights::AbstractMutationWeights, member::PopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) + condition_mutation_weights!(weights::AbstractMutationWeights, member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) Adjusts the mutation weights based on the properties of the current member and options. @@ -93,7 +94,7 @@ Note that the weights were already copied, so you don't need to worry about muta # Arguments - `weights::AbstractMutationWeights`: The mutation weights to be adjusted. -- `member::PopMember`: The current population member being mutated. +- `member::AbstractPopMember`: The current population member being mutated. - `options::AbstractOptions`: The options that guide the mutation process. - `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. - `nfeatures::Int`: The number of features available in the dataset. @@ -104,7 +105,7 @@ function condition_mutation_weights!( options::AbstractOptions, curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:AbstractExpression,P<:AbstractPopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 @@ -159,7 +160,7 @@ Use this to modify how `mutate_constant` changes for an expression type. function condition_mutate_constant!( ::Type{<:AbstractExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) @@ -181,7 +182,7 @@ end tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} parent_ref = member.ref num_evals = 0.0 @@ -253,14 +254,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -277,14 +277,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -321,14 +320,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -339,25 +337,22 @@ end tmp_recorder["reason"] = "pass" end mutation_accepted = true - return ( - PopMember( - tree, - after_cost, - after_loss, - options, - newSize; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, + new_member = create_child( + member, + tree, + after_cost, + after_loss, + options; + complexity=newSize, + parent_ref=parent_ref, ) + return (new_member, mutation_accepted, num_evals) end end @generated function _dispatch_mutations!( tree::AbstractExpression, - member::PopMember, + member::AbstractPopMember, mutation_choice::Symbol, weights::W, options::AbstractOptions; @@ -386,7 +381,7 @@ end mutation_weights::AbstractMutationWeights, options::AbstractOptions; kws..., - ) where {N<:AbstractExpression,P<:PopMember,S} + ) where {N<:AbstractExpression,P<:AbstractPopMember,S} Perform a mutation on the given `tree` and `member` using the specified mutation type `S`. Various `kws` are provided to access other data needed for some mutations. @@ -414,7 +409,7 @@ so it can always return immediately. """ function mutate!( ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws... -) where {N<:AbstractExpression,P<:PopMember,S} +) where {N<:AbstractExpression,P<:AbstractPopMember,S} return error("Unknown mutation choice: $S") end @@ -427,7 +422,7 @@ function mutate!( recorder::RecordType, temperature, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_constant(tree, temperature, options) @recorder recorder["type"] = "mutate_constant" return MutationResult{N,P}(; tree=tree) @@ -441,7 +436,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_operator(tree, options) @recorder recorder["type"] = "mutate_operator" return MutationResult{N,P}(; tree=tree) @@ -456,7 +451,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_feature(tree, nfeatures) @recorder recorder["type"] = "mutate_feature" return MutationResult{N,P}(; tree=tree) @@ -470,7 +465,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = swap_operands(tree) @recorder recorder["type"] = "swap_operands" return MutationResult{N,P}(; tree=tree) @@ -485,7 +480,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} if rand() < 0.5 tree = append_random_op(tree, options, nfeatures) @recorder recorder["type"] = "add_node:append" @@ -505,7 +500,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = insert_random_op(tree, options, nfeatures) @recorder recorder["type"] = "insert_node" return MutationResult{N,P}(; tree=tree) @@ -519,7 +514,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = delete_random_op!(tree) @recorder recorder["type"] = "delete_node" return MutationResult{N,P}(; tree=tree) @@ -533,7 +528,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = form_random_connection!(tree) @recorder recorder["type"] = "form_connection" return MutationResult{N,P}(; tree=tree) @@ -547,7 +542,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = break_random_connection!(tree) @recorder recorder["type"] = "break_connection" return MutationResult{N,P}(; tree=tree) @@ -561,7 +556,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = randomly_rotate_tree!(tree) @recorder recorder["type"] = "rotate_tree" return MutationResult{N,P}(; tree=tree) @@ -577,22 +572,15 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @assert options.should_simplify simplify_tree!(tree, options.operators) tree = combine_operators(tree, options.operators) @recorder recorder["type"] = "simplify" - return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options; - parent=parent_ref, - deterministic=options.deterministic, - ), - return_immediately=true, + new_member = create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ) + return MutationResult{N,P}(; member=new_member, return_immediately=true) end function mutate!( @@ -605,7 +593,7 @@ function mutate!( curmaxsize, nfeatures, kws..., -) where {T,N<:AbstractExpression{T},P<:PopMember} +) where {T,N<:AbstractExpression{T},P<:AbstractPopMember} tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) @@ -620,7 +608,7 @@ function mutate!( recorder::RecordType, dataset::Dataset, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} cur_member, new_num_evals = optimize_constants(dataset, member, options) @recorder recorder["type"] = "optimize" return MutationResult{N,P}(; @@ -637,21 +625,15 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @recorder begin recorder["type"] = "identity" recorder["result"] = "accept" recorder["reason"] = "identity" end return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options, - compute_complexity(tree, options); - parent=parent_ref, - deterministic=options.deterministic, + member=create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ), return_immediately=true, ) @@ -665,7 +647,7 @@ function crossover_generation( curmaxsize::Int, options::AbstractOptions; recorder::RecordType=RecordType(), -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:AbstractPopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree crossover_accepted = false @@ -704,23 +686,23 @@ function crossover_generation( ) num_evals += 2 * dataset_fraction(dataset) - baby1 = PopMember( + baby1 = create_child( + (member1, member2), child_tree1::AbstractExpression, after_cost1, after_loss1, - options, - afterSize1; - parent=member1.ref, - deterministic=options.deterministic, + options; + complexity=afterSize1, + parent_ref=member1.ref, )::P - baby2 = PopMember( + baby2 = create_child( + (member1, member2), child_tree2::AbstractExpression, after_cost2, after_loss2, - options, - afterSize2; - parent=member2.ref, - deterministic=options.deterministic, + options; + complexity=afterSize2, + parent_ref=member2.ref, )::P @recorder begin diff --git a/src/Options.jl b/src/Options.jl index cb8d3bbf3..6ee7c4222 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -225,6 +225,8 @@ recommend_loss_function_expression(expression_type) = false create_mutation_weights(w::AbstractMutationWeights) = w create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) +function default_popmember_type end + @unstable function with_max_degree_from_context( node_type, user_provided_operators, operators ) @@ -652,6 +654,7 @@ $(OPTION_DESCRIPTIONS) terminal_width::Union{Nothing,Integer}=nothing, use_recorder::Bool=false, recorder_file::AbstractString="pysr_recorder.json", + popmember_type::Type=default_popmember_type(), ### Not search options; just construction options: define_helper_functions::Bool=true, ######################################### @@ -1031,6 +1034,7 @@ $(OPTION_DESCRIPTIONS) expression_type, typeof(expression_options), typeof(set_mutation_weights), + popmember_type, turbo, bumper, deprecated_return_state::Union{Bool,Nothing}, @@ -1104,6 +1108,7 @@ $(OPTION_DESCRIPTIONS) deterministic, define_helper_functions, use_recorder, + popmember_type, ) return options diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 2c3046204..495035760 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -183,6 +183,7 @@ struct Options{ E<:AbstractExpression, EO<:NamedTuple, MW<:AbstractMutationWeights, + PM, _turbo, _bumper, _return_state, @@ -256,6 +257,7 @@ struct Options{ deterministic::Bool define_helper_functions::Bool use_recorder::Bool + popmember_type::Type{PM} end function Base.print(io::IO, @nospecialize(options::Options)) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 2717afbdc..b7c9ab2f6 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -24,7 +24,7 @@ using ..CoreModule: AbstractExpressionSpec, get_indices, ExpressionSpecModule as ES -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB @@ -102,7 +102,7 @@ end function MM.condition_mutate_constant!( ::Type{<:ParametricExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/src/PopMember.jl b/src/PopMember.jl index bd195a6c2..71f8707de 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,13 +2,32 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree +import DynamicExpressions: constructorof, with_type_parameters using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression +import ..CoreModule.OptionsModule: default_popmember_type import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost +""" + AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + +Abstract type for population members. Defines the interface that all population members must implement. + +# Required fields (accessed via getproperty/setproperty!) +- `tree::N`: The expression tree +- `cost::L`: The cost including complexity penalty and normalization +- `loss::L`: The raw loss value +- `birth::Int`: Birth order/generation number +- `ref::Int`: Unique reference ID +- `parent::Int`: Parent reference ID +- `complexity::Int`: Cached complexity (accessed via getfield/setfield! for special handling) +""" +abstract type AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} end + # Define a member of population by equation, cost, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} <: + AbstractPopMember{T,L,N} tree::N cost::L # Inludes complexity penalty, normalization loss::L # Raw loss @@ -19,7 +38,9 @@ mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} ref::Int parent::Int end -@inline function Base.setproperty!(member::PopMember, field::Symbol, value) + +# Generic interface implementations for AbstractPopMember +@inline function Base.setproperty!(member::AbstractPopMember, field::Symbol, value) if field == :complexity throw( error("Don't set `.complexity` directly. Use `recompute_complexity!` instead.") @@ -34,7 +55,7 @@ end end return setfield!(member, field, value) end -@unstable @inline function Base.getproperty(member::PopMember, field::Symbol) +@unstable @inline function Base.getproperty(member::AbstractPopMember, field::Symbol) if field == :complexity throw( error("Don't access `.complexity` directly. Use `compute_complexity` instead.") @@ -145,7 +166,7 @@ function PopMember( ) end -function Base.copy(p::P) where {P<:PopMember} +function Base.copy(p::PopMember) tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -153,17 +174,17 @@ function Base.copy(p::P) where {P<:PopMember} complexity = copy(getfield(p, :complexity)) ref = copy(p.ref) parent = copy(p.parent) - return P(tree, cost, loss, birth, complexity, ref, parent) + return PopMember(tree, cost, loss, birth, complexity, ref, parent) end -function reset_birth!(p::PopMember; deterministic::Bool) +function reset_birth!(p::AbstractPopMember; deterministic::Bool) p.birth = get_birth_order(; deterministic) return p end # Can read off complexity directly from pop members function compute_complexity( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = getfield(member, :complexity) complexity == -1 && return recompute_complexity!(member, options; break_sharing) @@ -171,11 +192,89 @@ function compute_complexity( return complexity end function recompute_complexity!( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = compute_complexity(member.tree, options; break_sharing) setfield!(member, :complexity, complexity) return complexity end +""" + create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref) where {T,L,P<:PopMember{T,L}} + +Create a new PopMember with a potentially different expression type. +Used by embed_metadata where the expression gains metadata. +""" +function create_child( + parent::P, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, +) where {T,L,P<:PopMember{T,L}} + actual_complexity = @something complexity compute_complexity(tree, options) + return PopMember( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + +""" + create_child(parents::Tuple{P,P}, tree, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref) where P<:AbstractPopMember + +Create a new PopMember from two parents (crossover case). +Custom types should override to blend their additional fields. +""" +function create_child( + parents::Tuple{P,P}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, +) where {T,L,P<:PopMember{T,L}} + actual_complexity = @something complexity compute_complexity(tree, options) + return PopMember( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + +# Function to extract PopMember type from Population or HallOfFame types +function popmember_type end + +@unstable default_popmember_type() = PopMember +@unstable constructorof(::Type{<:PopMember}) = PopMember + +@inline function with_expression_type( + ::Type{<:PopMember{T,L}}, ::Type{N} +) where {T,L,N<:AbstractExpression{T}} + return PopMember{T,L,N} +end + +@inline function with_type_parameters( + ::Type{<:PopMember}, ::Type{T}, ::Type{L}, ::Type{N} +) where {T,L,N} + return PopMember{T,L,N} +end + +@inline function expression_type(::Type{<:AbstractPopMember{<:Any,<:Any,N}}) where {N} + return N +end + end diff --git a/src/Population.jl b/src/Population.jl index 739ca828e..00f603258 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -2,26 +2,29 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpression, string_tree +using DynamicExpressions: AbstractExpression, string_tree, constructorof using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..UtilsModule: bottomk_fast, argmin_fast, PerTaskCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct Population{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} n::Int end """ - Population(pop::Array{PopMember{T,L}, 1}) + Population(pop::Array{<:AbstractPopMember, 1}) Create population from list of PopMembers. """ -function Population(pop::Vector{<:PopMember}) +function Population(pop::Vector{<:AbstractPopMember}) return Population(pop, size(pop, 1)) end @@ -41,23 +44,34 @@ function Population( npop=nothing, ) where {T,L} @assert (population_size !== nothing) ⊻ (npop !== nothing) - population_size = if npop === nothing - population_size - else - npop - end - return Population( - [ - PopMember( + population_size = something(population_size, npop) + PM = options.popmember_type + + # Create first member to get concrete type + first_member = constructorof(PM)( + dataset, + gen_random_tree(nlength, options, nfeatures, T), + options; + parent=-1, + deterministic=options.deterministic, + ) + + # Use the concrete type for the array + members = typeof(first_member)[ + if i == 1 + first_member + else + constructorof(PM)( dataset, gen_random_tree(nlength, options, nfeatures, T), options; parent=-1, deterministic=options.deterministic, - ) for _ in 1:population_size - ], - population_size, - ) + ) + end for i in 1:population_size + ] + + return Population(members, population_size) end """ Population(X::AbstractMatrix{T}, y::AbstractVector{T}; @@ -90,8 +104,8 @@ Create random population and score them on the dataset. ) end -function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} - copied_members = Vector{PopMember{T,L,N}}(undef, pop.n) +function Base.copy(pop::P)::P where {T,L,N,PM,P<:Population{T,L,N,PM}} + copied_members = Vector{PM}(undef, pop.n) Threads.@threads for i in 1:(pop.n) copied_members[i] = copy(pop.members[i]) end @@ -118,7 +132,7 @@ function _best_of_sample( members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N}} p = options.tournament_selection_p n = length(members) # == tournament_selection_n adjusted_costs = Vector{L}(undef, n) @@ -157,7 +171,7 @@ function _best_of_sample( end return members[chosen_idx] end -_get_cost(member::PopMember) = member.cost +_get_cost(member::AbstractPopMember) = member.cost const CACHED_WEIGHTS = let init_k = collect(0:5), @@ -218,4 +232,7 @@ function record_population(pop::Population, options::AbstractOptions)::RecordTyp ) end +# Type accessor for Population +popmember_type(::Type{<:Population{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 58d492e8a..05efe4c09 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -12,12 +12,18 @@ using DispatchDoctor: @unstable using Logging: AbstractLogger using DynamicExpressions: - AbstractExpression, string_tree, parse_expression, EvalOptions, with_type_parameters + AbstractExpression, + string_tree, + parse_expression, + EvalOptions, + with_type_parameters, + constructorof using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features +using ..CoreModule: + Dataset, AbstractOptions, Options, RecordType, max_features, create_expression using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -28,6 +34,15 @@ using ..CheckConstraintsModule: check_constraints function logging_callback! end +@unstable @inline function infer_popmember_type( + ::Type{T}, ::Type{L}, ::Type{D}, options +) where {T,L,D<:Dataset} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N in (Any, Union{}) && error("Failed to infer expression type") + return with_type_parameters(options.popmember_type, T, L, N) +end + """ @filtered_async expr @@ -581,8 +596,9 @@ The state of the search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. """ -Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} <: - AbstractSearchState{T,L,N} +Base.@kwdef struct SearchState{ + T,L,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N},WorkerOutputType,ChannelType +} <: AbstractSearchState{T,L,N} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -590,16 +606,16 @@ Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,Cha channels::Vector{Vector{ChannelType}} worker_assignment::WorkerAssignments task_order::Vector{Tuple{Int,Int}} - halls_of_fame::Vector{HallOfFame{T,L,N}} - last_pops::Vector{Vector{Population{T,L,N}}} - best_sub_pops::Vector{Vector{Population{T,L,N}}} + halls_of_fame::Vector{HallOfFame{T,L,N,PM}} + last_pops::Vector{Vector{Population{T,L,N,PM}}} + best_sub_pops::Vector{Vector{Population{T,L,N,PM}}} all_running_search_statistics::Vector{RunningSearchStatistics} num_evals::Vector{Vector{Float64}} cycles_remaining::Vector{Int} cur_maxsizes::Vector{Int} stdin_reader::StdinReader record::Base.RefValue{RecordType} - seed_members::Vector{Vector{PopMember{T,L,N}}} + seed_members::Vector{Vector{PM}} end function save_to_file( @@ -716,7 +732,7 @@ end function update_hall_of_fame!( hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions -) where {PM<:PopMember} +) where {PM<:AbstractPopMember} for member in members size = compute_complexity(member, options) valid_size = 0 < size <= options.maxsize @@ -786,12 +802,22 @@ end """Parse user-provided guess expressions and convert them into optimized `PopMember` objects for each output dataset.""" -function parse_guesses( +@unstable function parse_guesses( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} + ConcreteP = infer_popmember_type(T, L, D, options) + return _parse_guesses_impl(ConcreteP, guesses, datasets, options) +end + +@inline function _parse_guesses_impl( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L},D<:Dataset{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] guess_lists = _make_vector_vector(guesses, nout) @@ -799,7 +825,9 @@ function parse_guesses( dataset = datasets[j] for g in guess_lists[j] ex = _parse_guess_expression(T, g, dataset, options) - member = PopMember(dataset, ex, options; deterministic=options.deterministic) + member = constructorof(P)( + dataset, ex, options; deterministic=options.deterministic + ) if options.should_optimize_constants member, _ = optimize_constants(dataset, member, options) end @@ -817,6 +845,7 @@ function parse_guesses( end return out end + function _make_vector_vector(guesses, nout) if nout == 1 if guesses isa AbstractVector{<:AbstractVector} diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index a935a978c..4e9443f23 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -297,7 +297,9 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: PopMember, reset_birth! +using .PopMemberModule: + AbstractPopMember, PopMember, reset_birth!, popmember_type, expression_type +using .CoreModule.UtilsModule: get_birth_order using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -338,7 +340,8 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame!, parse_guesses, - logging_callback! + logging_callback!, + infer_popmember_type using .LoggingModule: AbstractSRLogger, SRLogger, get_logger using .TemplateExpressionModule: TemplateExpression, TemplateStructure, TemplateExpressionSpec, ParamVector, has_params @@ -630,11 +633,10 @@ end @recorder record["options"] = "$(options)" nout = length(datasets) - example_dataset = first(datasets) - example_ex = create_expression(init_value(T), options, example_dataset) - NT = typeof(example_ex) - PopType = Population{T,L,NT} - HallOfFameType = HallOfFame{T,L,NT} + PMType = infer_popmember_type(T, L, D, options) + NT = expression_type(PMType) + PopType = Population{T,L,NT,PMType} + HallOfFameType = HallOfFame{T,L,NT,PMType} WorkerOutputType = get_worker_output_type( Val(ropt.parallelism), PopType, HallOfFameType ) @@ -692,9 +694,9 @@ end j in 1:nout ] - seed_members = [PopMember{T,L,NT}[] for j in 1:nout] + seed_members = [Vector{PMType}() for j in 1:nout] - return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; + return SearchState{T,L,NT,PMType,WorkerOutputType,ChannelType}(; procs=procs, we_created_procs=we_created_procs, worker_output=worker_output, @@ -810,10 +812,14 @@ function _preserve_loaded_state!( options::AbstractOptions, ) where {T,L,N} nout = length(state.worker_output) + # Get the prototype to extract types + prototype_pop = state.last_pops[1][1] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + for j in 1:nout, i in 1:(options.populations) - (pop, _, _, _) = extract_from_worker( - state.worker_output[j][i], Population{T,L,N}, HallOfFame{T,L,N} - ) + (pop, _, _, _) = extract_from_worker(state.worker_output[j][i], PopType, HallType) state.last_pops[j][i] = copy(pop) end return nothing @@ -843,11 +849,16 @@ function _warmup_search!( # Multi-threaded doesn't like to fetch within a new task: c_rss = deepcopy(running_search_statistics) last_pop = state.worker_output[j][i] + + # Get the prototype to extract types + prototype_pop = state.last_pops[j][i] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + updated_pop = @sr_spawner( begin - in_pop = first( - extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) - ) + in_pop = first(extract_from_worker(last_pop, PopType, HallType)) _dispatch_s_r_cycle( in_pop, dataset, diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 201c047e7..7562208c6 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -52,7 +52,7 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..ComposableExpressionModule: ComposableExpression, ValidVector struct ParamVector{T} <: AbstractVector{T} @@ -745,7 +745,7 @@ function MM.condition_mutation_weights!( @nospecialize(options::AbstractOptions), curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:TemplateExpression,P<:AbstractPopMember{T,L,N}} if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 @@ -828,7 +828,7 @@ end function MM.condition_mutate_constant!( ::Type{<:TemplateExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/test/runtests.jl b/test/runtests.jl index 7aef02fea..f4bd81111 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -164,6 +164,7 @@ end end include("test_abstract_numbers.jl") +include("test_abstract_popmember.jl") include("test_logging.jl") include("test_pretty_printing.jl") diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl new file mode 100644 index 000000000..eabc6886d --- /dev/null +++ b/test/test_abstract_popmember.jl @@ -0,0 +1,192 @@ +@testitem "Custom AbstractPopMember implementation" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + using DispatchDoctor: @unstable + + import SymbolicRegression.PopMemberModule: create_child + + # Define a custom PopMember that tracks generation count + mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + tree::N + cost::L + loss::L + birth::Int + complexity::Int + ref::Int + parent::Int + generation::Int # Custom field to track generation + end + + # # Direct constructor that matches field order + function CustomPopMember( + tree::N, + cost::L, + loss::L, + birth::Int, + complexity::Int, + ref::Int, + parent::Int, + generation::Int, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember{T,L,N}( + tree, cost, loss, birth, complexity, ref, parent, generation + ) + end + + function CustomPopMember( + tree::N, + cost::L, + loss::L, + options, + complexity::Int; + parent=-1, + deterministic=nothing, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + # Constructor for Population initialization (dataset, tree, options) + function CustomPopMember( + dataset::SymbolicRegression.Dataset, tree, options; parent=-1, deterministic=nothing + ) + ex = SymbolicRegression.create_expression(tree, options, dataset) + complexity = SymbolicRegression.compute_complexity(ex, options) + cost, loss = SymbolicRegression.eval_cost( + dataset, ex, options; complexity=complexity + ) + + return CustomPopMember( + ex, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + @unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + + # Define with_type_parameters for CustomPopMember + @unstable function DynamicExpressions.with_type_parameters( + ::Type{<:CustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N} + ) where {T,L,N} + return CustomPopMember{T,L,N} + end + + # Define copy for CustomPopMember + function Base.copy(p::CustomPopMember) + return CustomPopMember( + copy(p.tree), + copy(p.cost), + copy(p.loss), + copy(p.birth), + copy(getfield(p, :complexity)), + copy(p.ref), + copy(p.parent), + copy(p.generation), + ) + end + + function create_child( + parent::CustomPopMember{T,L}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + ) where {T,L} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + actual_complexity, + abs(rand(Int)), + parent_ref, + parent.generation + 1, + ) + end + + function create_child( + parents::Tuple{<:CustomPopMember,<:CustomPopMember}, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + ) where {T,L,N<:AbstractExpression{T}} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + max_generation = max(parents[1].generation, parents[2].generation) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.CoreModule.UtilsModule.get_birth_order(; + deterministic=options.deterministic + ), + actual_complexity, + abs(rand(Int)), + parent_ref, + max_generation + 1, + ) + end + + # Test that we can run equation_search with CustomPopMember + X = randn(Float32, 2, 100) + y = @. X[1, :]^2 - X[2, :] + + options = SymbolicRegression.Options(; + binary_operators=[+, -], + populations=1, + population_size=20, + maxsize=5, + popmember_type=CustomPopMember, + deterministic=true, + seed=0, + ) + + # Test that options were created with correct type + @test options.popmember_type == CustomPopMember + + hall_of_fame = equation_search( + X, y; options=options, niterations=2, parallelism=:serial + ) + + # Verify that we got results + @test sum(hall_of_fame.exists) > 0 + + # Verify that the members are CustomPopMember + for i in eachindex(hall_of_fame.members, hall_of_fame.exists) + if hall_of_fame.exists[i] + @test hall_of_fame.members[i] isa CustomPopMember + # Check that generation field exists + @test hall_of_fame.members[i].generation >= 0 + end + end + + # Verify we can extract the best member + best_idx = findlast(hall_of_fame.exists) + @test !isnothing(best_idx) + best_member = hall_of_fame.members[best_idx] + @test best_member isa CustomPopMember +end