diff --git a/Project.toml b/Project.toml index f5a02e829..3d6620ace 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicRegression" uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" authors = ["MilesCranmer "] -version = "2.0.0-alpha.8" +version = "2.0.0-alpha.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index e92149614..3412242b5 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -88,10 +88,10 @@ function create_utils_benchmark() suite["best_of_sample"] = @benchmarkable( best_of_sample(pop, rss, $options), setup = ( - nfeatures=1; - dataset=Dataset(randn(nfeatures, 32), randn(32)); - pop=Population(dataset; npop=100, nlength=20, options=($options), nfeatures); - rss=RunningSearchStatistics(; options=($options)) + nfeatures = 1; + dataset = Dataset(randn(nfeatures, 32), randn(32)); + pop = Population(dataset; npop=100, nlength=20, options=($options), nfeatures); + rss = RunningSearchStatistics(; options=($options)) ) ) @@ -110,9 +110,9 @@ function create_utils_benchmark() end end, setup = ( - nfeatures=1; - dataset=Dataset(randn(nfeatures, 32), randn(32)); - mutation_weights=MutationWeights(; + nfeatures = 1; + dataset = Dataset(randn(nfeatures, 32), randn(32)); + mutation_weights = MutationWeights(; mutate_constant=1.0, mutate_operator=1.0, swap_operands=1.0, @@ -125,23 +125,21 @@ function create_utils_benchmark() form_connection=0.0, break_connection=0.0, ); - options=Options(; - unary_operators=[sin, cos], - binary_operators=[+, -, *, /], - mutation_weights, + options = Options(; + unary_operators=[sin, cos], binary_operators=[+, -, *, /], mutation_weights ); - recorder=RecordType(); - temperature=1.0; - curmaxsize=20; - rss=RunningSearchStatistics(; options); - trees=[ + recorder = RecordType(); + temperature = 1.0; + curmaxsize = 20; + rss = RunningSearchStatistics(; options); + trees = [ gen_random_tree_fixed_size(15, options, nfeatures, Float64) for _ in 1:100 ]; - expressions=[ + expressions = [ Expression(tree; operators=options.operators, variable_names=["x1"]) for tree in trees ]; - members=[ + members = [ PopMember(dataset, expression, options; deterministic=false) for expression in expressions ] @@ -155,14 +153,14 @@ function create_utils_benchmark() end, seconds = 20, setup = ( - nfeatures=1; - T=Float64; - dataset=Dataset(randn(nfeatures, 512), randn(512)); - ntrees=($ntrees); - trees=[ + nfeatures = 1; + T = Float64; + dataset = Dataset(randn(nfeatures, 512), randn(512)); + ntrees = ($ntrees); + trees = [ gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:ntrees ]; - members=[ + members = [ PopMember(dataset, tree, $options; deterministic=false) for tree in trees ] ) @@ -181,9 +179,9 @@ function create_utils_benchmark() compute_complexity(tree, $options) end, setup = ( - T=Float64; - nfeatures=3; - trees=[ + T = Float64; + nfeatures = 3; + trees = [ gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees) ] @@ -199,9 +197,9 @@ function create_utils_benchmark() SymbolicRegression.MutationFunctionsModule.randomly_rotate_tree!(tree) end, setup = ( - T=Float64; - nfeatures=3; - trees=[ + T = Float64; + nfeatures = 3; + trees = [ gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees) ] @@ -216,9 +214,9 @@ function create_utils_benchmark() ) end, setup = ( - T=Float64; - nfeatures=3; - trees=[ + T = Float64; + nfeatures = 3; + trees = [ gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees) 
] ) @@ -242,9 +240,9 @@ function create_utils_benchmark() check_constraints(tree, $options, $options.maxsize) end, setup = ( - T=Float64; - nfeatures=3; - trees=[ + T = Float64; + nfeatures = 3; + trees = [ gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees) ] ) diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index 72669983c..b612e6766 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -91,8 +91,8 @@ function check_constraints( return true end -check_constraints(ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions)::Bool = check_constraints( - ex, options, options.maxsize -) +check_constraints( + ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions +)::Bool = check_constraints(ex, options, options.maxsize) end diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 0666e1bb5..7e3ee3dd0 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -271,9 +271,8 @@ function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N} if all(_is_valid, x) return _apply_operator(op, x...) else - example_vector = something( - map(xi -> xi isa ValidVector ? xi : nothing, x)... - )::ValidVector + example_vector = + something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector expected_return_type = Base.promote_op( _apply_operator, typeof(op), map(typeof, x)... ) diff --git a/src/Configure.jl b/src/Configure.jl index e637c24c6..e1c04c7e5 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -243,10 +243,12 @@ function activate_env_on_workers( ) verbosity > 0 && @info "Activating environment on workers." @everywhere procs begin - Base.MainInclude.eval(quote - using Pkg - Pkg.activate($$project_path) - end) + Base.MainInclude.eval( + quote + using Pkg + Pkg.activate($$project_path) + end, + ) end end @@ -289,9 +291,12 @@ function import_module_on_workers( all_extensions = vcat(relevant_extensions, @something(worker_imports, Symbol[])) for ext in all_extensions - push!(expr.args, quote - using $ext: $ext - end) + push!( + expr.args, + quote + using $ext: $ext + end, + ) end verbosity > 0 && if isempty(relevant_extensions) diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 585a6ef78..1c6bb01f1 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -17,7 +17,7 @@ using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction using ..UtilsModule: get_birth_order, PerTaskCache, stable_get! 
using ..LossFunctionsModule: eval_loss, loss_to_cost -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember function can_optimize(::AbstractExpression{T}, options) where {T} return can_optimize(T, options) @@ -31,7 +31,7 @@ end member::P, options::AbstractOptions; rng::AbstractRNG=default_rng(), -)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:AbstractPopMember{T,L}} can_optimize(member.tree, options) || return (member, 0.0) nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) @@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T,L,P<:AbstractPopMember{T,L}} tree = member.tree x0, refs = get_scalar_constants(tree) @assert count_constants_for_optimization(tree) == length(x0) @@ -76,7 +76,7 @@ function _optimize_constants( end function _optimize_constants_inner( f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {F,G,T,L,P<:AbstractPopMember{T,L}} obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing f else diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 4de4e56a2..61caf3d0e 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -11,7 +11,7 @@ using DynamicExpressions: using ..CoreModule: AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember import DynamicExpressions: get_operators import ..CoreModule: create_expression @@ -134,8 +134,10 @@ end ) end function embed_metadata( - vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + vec::Vector{<:Union{HallOfFame,Population,AbstractPopMember}}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ) where {T,L} return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec) end end @@ -153,8 +155,8 @@ function strip_metadata( return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) 
end function strip_metadata( - member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} -) where {T,L} + member::PM, options::AbstractOptions, dataset::Dataset{T,L} +) where {T,L,PM<:PopMember{T,L}} return PopMember( strip_metadata(member.tree, options, dataset), member.cost, @@ -165,14 +167,14 @@ function strip_metadata( deterministic=options.deterministic, ) end -function strip_metadata( - pop::Population, options::AbstractOptions, dataset::Dataset{T,L} -) where {T,L} +@unstable function strip_metadata( + pop::P, options::AbstractOptions, dataset::Dataset{T,L} +) where {T,L,P<:Population{T,L}} return Population(map(member -> strip_metadata(member, options, dataset), pop.members)) end function strip_metadata( - hof::HallOfFame, options::AbstractOptions, dataset::Dataset{T,L} -) where {T,L} + hof::H, options::AbstractOptions, dataset::Dataset{T,L} +) where {T,L,N,PM<:AbstractPopMember,H<:HallOfFame{T,L,N,PM}} return HallOfFame( map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists ) diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index c90fbe3a4..474eadc40 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -6,7 +6,7 @@ using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf @@ -23,8 +23,10 @@ have been set, you can run `.members[exists]`. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. """ -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct HallOfFame{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} exists::Array{Bool,1} #Whether it has been set end function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} @@ -69,7 +71,7 @@ function HallOfFame( ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) - return HallOfFame{T,L,typeof(base_tree)}( + return HallOfFame{T,L,typeof(base_tree),PopMember{T,L,typeof(base_tree)}}( [ PopMember( copy(base_tree), @@ -95,9 +97,8 @@ end """ function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} # TODO - remove dataset from args. 
- P = PopMember{T,L,N} # Dominating pareto curve - must be better than all simpler equations - dominating = P[] + dominating = similar(hallOfFame.members, 0) for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index bb2fddfb0..2f6e441fc 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -595,9 +595,9 @@ end function get_equation_strings_for( ::AbstractSingletargetSRRegressor, trees, options, variable_names ) - return (t -> string_tree(t, options; variable_names=variable_names, pretty=false)).( - trees - ) + return ( + t -> string_tree(t, options; variable_names=variable_names, pretty=false) + ).(trees) end function get_equation_strings_for( ::AbstractMultitargetSRRegressor, trees, options, variable_names diff --git a/src/Migration.jl b/src/Migration.jl index f7fe61b89..d2446954f 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -2,7 +2,7 @@ module MigrationModule using ..CoreModule: AbstractOptions using ..PopulationModule: Population -using ..PopMemberModule: PopMember, reset_birth! +using ..PopMemberModule: AbstractPopMember, reset_birth! using ..UtilsModule: poisson_sample """ @@ -14,7 +14,7 @@ to do so. The original migrant population is not modified. Pass with, e.g., """ function migrate!( migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat -) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} +) where {T,L,N,PM<:AbstractPopMember{T,L,N},P<:Population{T,L,N}} base_pop = migration.second population_size = length(base_pop.members) mean_number_replaced = population_size * frac diff --git a/src/Mutate.jl b/src/Mutate.jl index f5fd88457..480d550ed 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -40,10 +40,10 @@ using ..MutationFunctionsModule: using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder -abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end +abstract type AbstractMutationResult{N<:AbstractExpression,P<:AbstractPopMember} end """ - MutationResult{N<:AbstractExpression,P<:PopMember} + MutationResult{N<:AbstractExpression,P<:AbstractPopMember} Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions. @@ -61,7 +61,8 @@ This struct encapsulates the result of a mutation operation. Either a new expres Return the `member` if you want to return immediately, and have computed the loss value as part of the mutation. 
""" -struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P} +struct MutationResult{N<:AbstractExpression,P<:AbstractPopMember} <: + AbstractMutationResult{N,P} tree::Union{N,Nothing} member::Union{P,Nothing} num_evals::Float64 @@ -73,7 +74,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes member::Union{_P,Nothing}=nothing, num_evals::Float64=0.0, return_immediately::Bool=false, - ) where {_N<:AbstractExpression,_P<:PopMember} + ) where {_N<:AbstractExpression,_P<:AbstractPopMember} @assert( (tree === nothing) ⊻ (member === nothing), "Mutation result must return either a tree or a pop member, not both" @@ -104,7 +105,7 @@ function condition_mutation_weights!( options::AbstractOptions, curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:AbstractExpression,P<:AbstractPopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 @@ -159,10 +160,10 @@ Use this to modify how `mutate_constant` changes for an expression type. function condition_mutate_constant!( ::Type{<:AbstractExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::PM, options::AbstractOptions, curmaxsize::Int, -) +) where {PM<:AbstractPopMember} n_constants = count_scalar_constants(member.tree) weights.mutate_constant *= min(8, n_constants) / 8.0 @@ -181,7 +182,7 @@ end tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} parent_ref = member.ref num_evals = 0.0 @@ -357,12 +358,12 @@ end @generated function _dispatch_mutations!( tree::AbstractExpression, - member::PopMember, + member::PM, mutation_choice::Symbol, weights::W, options::AbstractOptions; kws..., -) where {W<:AbstractMutationWeights} +) where {W<:AbstractMutationWeights,PM<:AbstractPopMember} mutation_choices = fieldnames(W) quote Base.Cartesian.@nif( @@ -386,7 +387,7 @@ end mutation_weights::AbstractMutationWeights, options::AbstractOptions; kws..., - ) where {N<:AbstractExpression,P<:PopMember,S} + ) where {N<:AbstractExpression,P<:AbstractPopMember,S} Perform a mutation on the given `tree` and `member` using the specified mutation type `S`. Various `kws` are provided to access other data needed for some mutations. @@ -414,7 +415,7 @@ so it can always return immediately. """ function mutate!( ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws... 
-) where {N<:AbstractExpression,P<:PopMember,S} +) where {N<:AbstractExpression,P<:AbstractPopMember,S} return error("Unknown mutation choice: $S") end @@ -427,7 +428,7 @@ function mutate!( recorder::RecordType, temperature, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_constant(tree, temperature, options) @recorder recorder["type"] = "mutate_constant" return MutationResult{N,P}(; tree=tree) @@ -441,7 +442,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_operator(tree, options) @recorder recorder["type"] = "mutate_operator" return MutationResult{N,P}(; tree=tree) @@ -456,7 +457,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_feature(tree, nfeatures) @recorder recorder["type"] = "mutate_feature" return MutationResult{N,P}(; tree=tree) @@ -470,7 +471,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = swap_operands(tree) @recorder recorder["type"] = "swap_operands" return MutationResult{N,P}(; tree=tree) @@ -485,7 +486,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} if rand() < 0.5 tree = append_random_op(tree, options, nfeatures) @recorder recorder["type"] = "add_node:append" @@ -505,7 +506,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = insert_random_op(tree, options, nfeatures) @recorder recorder["type"] = "insert_node" return MutationResult{N,P}(; tree=tree) @@ -519,7 +520,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = delete_random_op!(tree) @recorder recorder["type"] = "delete_node" return MutationResult{N,P}(; tree=tree) @@ -533,7 +534,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = form_random_connection!(tree) @recorder recorder["type"] = "form_connection" return MutationResult{N,P}(; tree=tree) @@ -547,7 +548,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = break_random_connection!(tree) @recorder recorder["type"] = "break_connection" return MutationResult{N,P}(; tree=tree) @@ -561,7 +562,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = randomly_rotate_tree!(tree) @recorder recorder["type"] = "rotate_tree" return MutationResult{N,P}(; tree=tree) @@ -577,7 +578,7 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @assert options.should_simplify simplify_tree!(tree, options.operators) tree = combine_operators(tree, 
options.operators) @@ -605,7 +606,7 @@ function mutate!( curmaxsize, nfeatures, kws..., -) where {T,N<:AbstractExpression{T},P<:PopMember} +) where {T,N<:AbstractExpression{T},P<:AbstractPopMember} tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) @@ -620,7 +621,7 @@ function mutate!( recorder::RecordType, dataset::Dataset, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} cur_member, new_num_evals = optimize_constants(dataset, member, options) @recorder recorder["type"] = "optimize" return MutationResult{N,P}(; @@ -637,7 +638,7 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @recorder begin recorder["type"] = "identity" recorder["result"] = "accept" @@ -665,7 +666,7 @@ function crossover_generation( curmaxsize::Int, options::AbstractOptions; recorder::RecordType=RecordType(), -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:AbstractPopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree crossover_accepted = false diff --git a/src/Options.jl b/src/Options.jl index cb8d3bbf3..cde0d5a31 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -150,8 +150,7 @@ end break end end - found_degree == 0 && - error("Operator $(op) is not in the operator set.") + found_degree == 0 && error("Operator $(op) is not in the operator set.") (found_degree, found_idx) end, new_max_nesting_dict = [ @@ -167,7 +166,7 @@ end end end found_degree == 0 && - error("Operator $(nested_op) is not in the operator set.") + error("Operator $(nested_op) is not in the operator set.") (found_degree, found_idx) end (nested_degree, nested_idx, max_nesting) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 2717afbdc..185753b8a 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -24,7 +24,7 @@ using ..CoreModule: AbstractExpressionSpec, get_indices, ExpressionSpecModule as ES -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB @@ -102,10 +102,10 @@ end function MM.condition_mutate_constant!( ::Type{<:ParametricExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::PM, options::AbstractOptions, curmaxsize::Int, -) +) where {PM<:AbstractPopMember} # Avoid modifying the mutate_constant weight, since # otherwise we would be mutating constants all the time! 
return nothing diff --git a/src/PopMember.jl b/src/PopMember.jl index bd195a6c2..f7e6763ef 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -7,8 +7,11 @@ import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost +abstract type AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} end + # Define a member of population by equation, cost, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} <: + AbstractPopMember{T,L,N} tree::N cost::L # Inludes complexity penalty, normalization loss::L # Raw loss @@ -19,7 +22,9 @@ mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} ref::Int parent::Int end -@inline function Base.setproperty!(member::PopMember, field::Symbol, value) +@inline function Base.setproperty!( + member::M, field::Symbol, value +) where {M<:AbstractPopMember} if field == :complexity throw( error("Don't set `.complexity` directly. Use `recompute_complexity!` instead.") @@ -34,7 +39,9 @@ end end return setfield!(member, field, value) end -@unstable @inline function Base.getproperty(member::PopMember, field::Symbol) +@unstable @inline function Base.getproperty( + member::M, field::Symbol +) where {M<:AbstractPopMember} if field == :complexity throw( error("Don't access `.complexity` directly. Use `compute_complexity` instead.") @@ -47,7 +54,7 @@ end end return getfield(member, field) end -function Base.show(io::IO, p::PopMember{T,L,N}) where {T,L,N} +function Base.show(io::IO, p::PM) where {T,L,N,PM<:PopMember{T,L,N}} shower(x) = sprint(show, x) print(io, "PopMember(") print(io, "tree = (", string_tree(p.tree), "), ") @@ -145,7 +152,7 @@ function PopMember( ) end -function Base.copy(p::P) where {P<:PopMember} +function Base.copy(p::P) where {P<:AbstractPopMember} tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -156,23 +163,23 @@ function Base.copy(p::P) where {P<:PopMember} return P(tree, cost, loss, birth, complexity, ref, parent) end -function reset_birth!(p::PopMember; deterministic::Bool) +function reset_birth!(p::M; deterministic::Bool) where {M<:AbstractPopMember} p.birth = get_birth_order(; deterministic) return p end # Can read off complexity directly from pop members function compute_complexity( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) -)::Int + member::M, options::AbstractOptions; break_sharing=Val(false) +)::Int where {M<:AbstractPopMember} complexity = getfield(member, :complexity) complexity == -1 && return recompute_complexity!(member, options; break_sharing) # TODO: Turn this into a warning, and then return normal compute_complexity instead. return complexity end function recompute_complexity!( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) -)::Int + member::M, options::AbstractOptions; break_sharing=Val(false) +)::Int where {M<:AbstractPopMember} complexity = compute_complexity(member.tree, options; break_sharing) setfield!(member, :complexity, complexity) return complexity diff --git a/src/Population.jl b/src/Population.jl index 739ca828e..a9abaf272 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -8,12 +8,12 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! 
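The src/PopMember.jl hunk above introduces `AbstractPopMember{T,L,N}` and retargets the generic helpers (`Base.copy`, `reset_birth!`, `compute_complexity`, `recompute_complexity!`, and the `getproperty`/`setproperty!` guards) at the abstract type. A minimal sketch of what a conforming custom member type might look like, inferred only from those generic methods; the name `MyMember` and the exact field types are assumptions that mirror `PopMember`, and this sketch is not part of the diff:

    # Hypothetical subtype. The field order matters if the generic fallback
    # `Base.copy(p::P) where {P<:AbstractPopMember}` shown above is relied on,
    # since it rebuilds the member via
    # `P(tree, cost, loss, birth, complexity, ref, parent)`.
    mutable struct MyMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} <:
                   AbstractPopMember{T,L,N}
        tree::N
        cost::L         # loss plus parsimony/normalization terms, as in PopMember
        loss::L         # raw loss
        birth::Int
        complexity::Int # cached; read through compute_complexity, not .complexity
        ref::Int
        parent::Int
    end
    # A type with a different layout would instead overload Base.copy (and any
    # other fallback that assumes PopMember's positional constructor).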
using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..UtilsModule: bottomk_fast, argmin_fast, PerTaskCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} + members::Array{<:AbstractPopMember{T,L,N},1} n::Int end """ @@ -21,7 +21,7 @@ end Create population from list of PopMembers. """ -function Population(pop::Vector{<:PopMember}) +function Population(pop::Vector{<:AbstractPopMember{T,L}}) where {T<:DATA_TYPE,L<:LOSS_TYPE} return Population(pop, size(pop, 1)) end @@ -91,7 +91,12 @@ Create random population and score them on the dataset. end function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} - copied_members = Vector{PopMember{T,L,N}}(undef, pop.n) + PM = if length(pop.members) > 0 + typeof(pop.members[1]) + else + AbstractPopMember{T,L,N} + end + copied_members = Vector{PM}(undef, pop.n) Threads.@threads for i in 1:(pop.n) copied_members[i] = copy(pop.members[i]) end @@ -118,7 +123,7 @@ function _best_of_sample( members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L}} +) where {T,L,P<:AbstractPopMember{T,L}} p = options.tournament_selection_p n = length(members) # == tournament_selection_n adjusted_costs = Vector{L}(undef, n) @@ -157,7 +162,7 @@ function _best_of_sample( end return members[chosen_idx] end -_get_cost(member::PopMember) = member.cost +_get_cost(member::AbstractPopMember) = member.cost const CACHED_WEIGHTS = let init_k = collect(0:5), @@ -201,19 +206,21 @@ function best_sub_pop(pop::P; topn::Int=10)::P where {P<:Population} return Population(pop.members[best_idx[1:topn]]) end +function generate_record(member::PopMember, options::AbstractOptions)::RecordType + return RecordType( + "tree" => string_tree(member.tree, options; pretty=false), + "loss" => member.loss, + "cost" => member.cost, + "complexity" => compute_complexity(member, options), + "birth" => member.birth, + "ref" => member.ref, + "parent" => member.parent, + ) +end + function record_population(pop::Population, options::AbstractOptions)::RecordType return RecordType( - "population" => [ - RecordType( - "tree" => string_tree(member.tree, options; pretty=false), - "loss" => member.loss, - "cost" => member.cost, - "complexity" => compute_complexity(member, options), - "birth" => member.birth, - "ref" => member.ref, - "parent" => member.parent, - ) for member in pop.members - ], + "population" => [generate_record(member, options) for member in pop.members], "time" => time(), ) end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 58d492e8a..e568bfd7c 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -17,7 +17,7 @@ using ..UtilsModule: subscriptify using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -37,16 +37,18 @@ to avoid spam when worker processes exit normally. 
macro filtered_async(expr) return esc( quote - $(Base).errormonitor(@async begin - try - $expr - catch ex - if !(ex isa $(Distributed).ProcessExitedException) - rethrow(ex) + $(Base).errormonitor( + @async begin + try + $expr + catch ex + if !(ex isa $(Distributed).ProcessExitedException) + rethrow(ex) + end end end - end) - end + ) + end, ) end @@ -281,9 +283,9 @@ function get_worker_output_type( end #! format: off -extract_from_worker(p::DefaultWorkerOutputType, _, _) = p -extract_from_worker(f::Future, ::Type{P}, ::Type{H}) where {P,H} = fetch(f)::DefaultWorkerOutputType{P,H} -extract_from_worker(t::Task, ::Type{P}, ::Type{H}) where {P,H} = fetch(t)::DefaultWorkerOutputType{P,H} +@unstable extract_from_worker(p::DefaultWorkerOutputType, _, _) = p +@unstable extract_from_worker(f::Future, ::Type{P}, ::Type{H}) where {P,H} = fetch(f)::DefaultWorkerOutputType{P,H} +@unstable extract_from_worker(t::Task, ::Type{P}, ::Type{H}) where {P,H} = fetch(t)::DefaultWorkerOutputType{P,H} #! format: on macro sr_spawner(expr, kws...) @@ -716,7 +718,7 @@ end function update_hall_of_fame!( hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions -) where {PM<:PopMember} +) where {PM<:AbstractPopMember} for member in members size = compute_complexity(member, options) valid_size = 0 < size <= options.maxsize @@ -791,7 +793,7 @@ function parse_guesses( guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L},D<:Dataset{T,L}} +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] guess_lists = _make_vector_vector(guesses, nout) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index a935a978c..6871fa955 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -2,6 +2,7 @@ module SymbolicRegression # Types export Population, + AbstractPopMember, PopMember, HallOfFame, Options, @@ -297,7 +298,7 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: PopMember, reset_birth! +using .PopMemberModule: AbstractPopMember, PopMember, reset_birth! 
using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -582,10 +583,10 @@ end guesses, ) where {D<:Dataset} _validate_options(datasets, ropt, options) - state = _create_workers(datasets, ropt, options) + state = _create_workers(PopMember, datasets, ropt, options) _initialize_search!(state, datasets, ropt, options, saved_state, guesses) - _warmup_search!(state, datasets, ropt, options) - _main_search_loop!(state, datasets, ropt, options) + _warmup_search!(PopMember, state, datasets, ropt, options) + _main_search_loop!(PopMember, state, datasets, ropt, options) _tear_down!(state, ropt, options) _info_dump(state, datasets, ropt, options) return _format_output(state, datasets, ropt, options) @@ -622,8 +623,8 @@ function _validate_options( return nothing end @stable default_mode = "disable" function _create_workers( - datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions -) where {T,L,D<:Dataset{T,L}} + ::Type{PM}, datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions +) where {T,L,D<:Dataset{T,L},PM<:AbstractPopMember} stdin_reader = watch_stream(options.input_stream) record = RecordType() @@ -634,7 +635,7 @@ end example_ex = create_expression(init_value(T), options, example_dataset) NT = typeof(example_ex) PopType = Population{T,L,NT} - HallOfFameType = HallOfFame{T,L,NT} + HallOfFameType = HallOfFame{T,L,NT,PM{T,L,NT}} WorkerOutputType = get_worker_output_type( Val(ropt.parallelism), PopType, HallOfFameType ) @@ -820,11 +821,12 @@ function _preserve_loaded_state!( end function _warmup_search!( + ::Type{PM}, state::AbstractSearchState{T,L,N}, datasets, ropt::AbstractRuntimeOptions, options::AbstractOptions, -) where {T,L,N} +) where {T,L,N,PM<:AbstractPopMember} if ropt.niterations == 0 return _preserve_loaded_state!(state, ropt, options) end @@ -846,7 +848,9 @@ function _warmup_search!( updated_pop = @sr_spawner( begin in_pop = first( - extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) + extract_from_worker( + last_pop, Population{T,L,N}, HallOfFame{T,L,N,PM{T,L,N}} + ), ) _dispatch_s_r_cycle( in_pop, @@ -858,7 +862,7 @@ function _warmup_search!( ropt.verbosity, cur_maxsize, running_search_statistics=c_rss, - )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} + )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N,PM{T,L,N}}} end, parallelism = ropt.parallelism, worker_idx = worker_idx @@ -868,11 +872,12 @@ function _warmup_search!( return nothing end function _main_search_loop!( + ::Type{PM}, state::AbstractSearchState{T,L,N}, datasets, ropt::AbstractRuntimeOptions, options::AbstractOptions, -) where {T,L,N} +) where {T,L,N,PM<:AbstractPopMember} ropt.verbosity > 0 && @info "Started!" 
nout = length(datasets) start_time = time() @@ -946,7 +951,7 @@ function _main_search_loop!( ) else state.worker_output[j][i] - end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} + end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N,PM{T,L,N}}} state.last_pops[j][i] = copy(cur_pop) state.best_sub_pops[j][i] = best_sub_pop(cur_pop; topn=options.topn) @recorder state.record[] = recursive_merge(state.record[], cur_record) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 201c047e7..473ca1e67 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -52,7 +52,7 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..ComposableExpressionModule: ComposableExpression, ValidVector struct ParamVector{T} <: AbstractVector{T} @@ -745,7 +745,7 @@ function MM.condition_mutation_weights!( @nospecialize(options::AbstractOptions), curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:TemplateExpression,P<:AbstractPopMember{T,L,N}} if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 @@ -828,10 +828,10 @@ end function MM.condition_mutate_constant!( ::Type{<:TemplateExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::PM, options::AbstractOptions, curmaxsize::Int, -) +) where {PM<:AbstractPopMember} # Avoid modifying the mutate_constant weight, since # otherwise we would be mutating constants all the time! return nothing diff --git a/test/manual_distributed.jl b/test/manual_distributed.jl index e005155fd..ae3c52bb7 100644 --- a/test/manual_distributed.jl +++ b/test/manual_distributed.jl @@ -5,10 +5,12 @@ procs = addprocs(2) using Test, Pkg project_path = splitdir(Pkg.project().path)[1] @everywhere procs begin - Base.MainInclude.eval(quote - using Pkg - Pkg.activate($$project_path) - end) + Base.MainInclude.eval( + quote + using Pkg + Pkg.activate($$project_path) + end, + ) end @everywhere using SymbolicRegression @everywhere _inv(x::Float32)::Float32 = 1.0f0 / x diff --git a/test/test_units.jl b/test/test_units.jl index dd36ca6c9..05a2575e7 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -137,9 +137,9 @@ end !has_cos(member.tree) || any( t -> t.degree == 1 && - options.operators.unaops[t.op] == cos && - Node(Float64; feature=1) in t && - compute_complexity(t, options) > 1, + options.operators.unaops[t.op] == cos && + Node(Float64; feature=1) in t && + compute_complexity(t, options) > 1, get_tree(member.tree), ) for member in dominating ] @@ -432,6 +432,7 @@ end @testitem "Miscellaneous tests of unit interface" tags = [:part3] begin using SymbolicRegression using DynamicQuantities + using MLJBase using SymbolicRegression.DimensionalAnalysisModule: @maybe_return_call, WildcardQuantity using SymbolicRegression.MLJInterfaceModule: unwrap_units_single using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_dimensions_type diff --git a/test/test_utils.jl b/test/test_utils.jl index 2433476a1..67ceb0dcc 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -9,8 +9,8 @@ function simple_bottomk(x, k) end array_options = [ - (n=n, seed=seed, T=T) for - n in (1, 5, 20, 50, 100, 1000), seed in 1:10, T in (Float32, Float64, Int) + (n=n, seed=seed, T=T) for n in (1, 
5, 20, 50, 100, 1000), seed in 1:10, + T in (Float32, Float64, Int) ] @testset "argmin_fast" begin
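Stepping back from the individual hunks: the central change in this diff is that a member type is threaded through the search machinery — `HallOfFame` gains a fourth type parameter, `Population.members` accepts any `AbstractPopMember`, and `_create_workers`/`_warmup_search!`/`_main_search_loop!` take a `::Type{PM}` argument (currently always passed `PopMember`). A rough sketch of how the concrete types line up; the helper `worker_types` is hypothetical and only mirrors the wiring in `_create_workers` above, with `T`, `L`, `N` standing for the data, loss, and expression types resolved from the dataset:

    using SymbolicRegression: Population, HallOfFame, PopMember, AbstractPopMember

    # PM{T,L,N} is exactly the pattern the diff uses for the new fourth
    # parameter of HallOfFame (see HallOfFame{T,L,N,PM{T,L,N}} in _warmup_search!).
    function worker_types(
        ::Type{PM}, ::Type{T}, ::Type{L}, ::Type{N}
    ) where {PM<:AbstractPopMember,T,L,N}
        pop_type = Population{T,L,N}
        hof_type = HallOfFame{T,L,N,PM{T,L,N}}  # member type is now part of the type
        return (; pop_type, hof_type)
    end

    # e.g., with the default member type and an expression ex::AbstractExpression{Float64}:
    # worker_types(PopMember, Float64, Float64, typeof(ex))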