2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "2.0.0-alpha.8"
version = "2.0.0-alpha.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
70 changes: 34 additions & 36 deletions benchmark/benchmarks.jl
@@ -88,10 +88,10 @@ function create_utils_benchmark()
suite["best_of_sample"] = @benchmarkable(
best_of_sample(pop, rss, $options),
setup = (
nfeatures=1;
dataset=Dataset(randn(nfeatures, 32), randn(32));
pop=Population(dataset; npop=100, nlength=20, options=($options), nfeatures);
rss=RunningSearchStatistics(; options=($options))
nfeatures = 1;
dataset = Dataset(randn(nfeatures, 32), randn(32));
pop = Population(dataset; npop=100, nlength=20, options=($options), nfeatures);
rss = RunningSearchStatistics(; options=($options))
)
)

@@ -110,9 +110,9 @@ function create_utils_benchmark()
end
end,
setup = (
nfeatures=1;
dataset=Dataset(randn(nfeatures, 32), randn(32));
mutation_weights=MutationWeights(;
nfeatures = 1;
dataset = Dataset(randn(nfeatures, 32), randn(32));
mutation_weights = MutationWeights(;
mutate_constant=1.0,
mutate_operator=1.0,
swap_operands=1.0,
@@ -125,23 +125,21 @@ function create_utils_benchmark()
form_connection=0.0,
break_connection=0.0,
);
options=Options(;
unary_operators=[sin, cos],
binary_operators=[+, -, *, /],
mutation_weights,
options = Options(;
unary_operators=[sin, cos], binary_operators=[+, -, *, /], mutation_weights
);
recorder=RecordType();
temperature=1.0;
curmaxsize=20;
rss=RunningSearchStatistics(; options);
trees=[
recorder = RecordType();
temperature = 1.0;
curmaxsize = 20;
rss = RunningSearchStatistics(; options);
trees = [
gen_random_tree_fixed_size(15, options, nfeatures, Float64) for _ in 1:100
];
expressions=[
expressions = [
Expression(tree; operators=options.operators, variable_names=["x1"]) for
tree in trees
];
members=[
members = [
PopMember(dataset, expression, options; deterministic=false) for
expression in expressions
]
@@ -155,14 +153,14 @@ function create_utils_benchmark()
end,
seconds = 20,
setup = (
nfeatures=1;
T=Float64;
dataset=Dataset(randn(nfeatures, 512), randn(512));
ntrees=($ntrees);
trees=[
nfeatures = 1;
T = Float64;
dataset = Dataset(randn(nfeatures, 512), randn(512));
ntrees = ($ntrees);
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:ntrees
];
members=[
members = [
PopMember(dataset, tree, $options; deterministic=false) for tree in trees
]
)
@@ -181,9 +179,9 @@ function create_utils_benchmark()
compute_complexity(tree, $options)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for
i in 1:($ntrees)
]
@@ -199,9 +197,9 @@ function create_utils_benchmark()
SymbolicRegression.MutationFunctionsModule.randomly_rotate_tree!(tree)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for
i in 1:($ntrees)
]
@@ -216,9 +214,9 @@ function create_utils_benchmark()
)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees)
]
)
@@ -242,9 +240,9 @@ function create_utils_benchmark()
check_constraints(tree, $options, $options.maxsize)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees)
]
)
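A note on the pattern used throughout this file: `@benchmarkable` takes a `setup` phase of `;`-separated assignments that run before each sample, so input construction stays out of the measured time. A minimal, self-contained sketch of the same pattern, with a hypothetical `my_op` standing in for `best_of_sample` and friends:

```julia
using BenchmarkTools

# Hypothetical stand-in for the operations benchmarked above.
my_op(xs) = sum(abs2, xs)

suite = BenchmarkGroup()
suite["my_op"] = @benchmarkable(
    my_op(xs),
    # Locals defined in `setup` are visible inside the benchmarked expression.
    setup = (n = 32; xs = randn(n))
)
tune!(suite)
run(suite)
```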
6 changes: 3 additions & 3 deletions src/CheckConstraints.jl
@@ -91,8 +91,8 @@ function check_constraints(
return true
end

check_constraints(ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions)::Bool = check_constraints(
ex, options, options.maxsize
)
check_constraints(
ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions
)::Bool = check_constraints(ex, options, options.maxsize)

end
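The reformatted method above is the usual convenience-overload idiom: a two-argument entry point forwarding a default (`options.maxsize`) to the full three-argument method. A generic sketch of the same idiom, with hypothetical names:

```julia
struct Opts
    maxsize::Int
end

# Full method: check a size against an explicit limit.
within_limit(size::Int, opts::Opts, maxsize::Int)::Bool = size <= maxsize

# Convenience overload: forward the default stored in the options.
within_limit(size::Int, opts::Opts)::Bool = within_limit(size, opts, opts.maxsize)

opts = Opts(20)
within_limit(15, opts)      # true, uses opts.maxsize
within_limit(25, opts, 30)  # true, explicit limit wins
```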
5 changes: 2 additions & 3 deletions src/ComposableExpression.jl
@@ -271,9 +271,8 @@ function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
if all(_is_valid, x)
return _apply_operator(op, x...)
else
example_vector = something(
map(xi -> xi isa ValidVector ? xi : nothing, x)...
)::ValidVector
example_vector =
something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector
expected_return_type = Base.promote_op(
_apply_operator, typeof(op), map(typeof, x)...
)
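The re-wrapped line relies on a compact idiom: `something` returns its first non-`nothing` argument, so splatting a `map` that blanks out non-matching elements extracts the first element of the desired type. Standalone, with plain types:

```julia
x = (1, "two", 3.0)

# Replace non-String elements with `nothing`, then let `something`
# pick the first survivor (it throws if everything is `nothing`).
first_string = something(map(xi -> xi isa String ? xi : nothing, x)...)
@assert first_string == "two"
```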
19 changes: 12 additions & 7 deletions src/Configure.jl
@@ -243,10 +243,12 @@ function activate_env_on_workers(
)
verbosity > 0 && @info "Activating environment on workers."
@everywhere procs begin
Base.MainInclude.eval(quote
using Pkg
Pkg.activate($$project_path)
end)
Base.MainInclude.eval(
quote
using Pkg
Pkg.activate($$project_path)
end,
)
end
end

@@ -289,9 +291,12 @@ function import_module_on_workers(
all_extensions = vcat(relevant_extensions, @something(worker_imports, Symbol[]))

for ext in all_extensions
push!(expr.args, quote
using $ext: $ext
end)
push!(
expr.args,
quote
using $ext: $ext
end,
)
end

verbosity > 0 && if isempty(relevant_extensions)
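The doubled `$$project_path` above carries the value through two levels of quoting: the inner `quote` block and the block captured by `@everywhere`. In plain code a single `$` suffices; a sketch of the same splicing on a single process, with a hypothetical path:

```julia
using Distributed, Pkg

project_path = "/tmp/my_env"  # hypothetical project directory

# A single `$` splices the runtime value into the expression.
ex = quote
    using Pkg
    Pkg.activate($project_path)
end

# The finished expression can then be shipped to workers, e.g.:
# remotecall_eval(Main, workers(), ex)
```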
8 changes: 4 additions & 4 deletions src/ConstantOptimization.jl
@@ -17,7 +17,7 @@ using ..CoreModule:
AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction
using ..UtilsModule: get_birth_order, PerTaskCache, stable_get!
using ..LossFunctionsModule: eval_loss, loss_to_cost
using ..PopMemberModule: PopMember
using ..PopMemberModule: AbstractPopMember

function can_optimize(::AbstractExpression{T}, options) where {T}
return can_optimize(T, options)
@@ -31,7 +31,7 @@ end
member::P,
options::AbstractOptions;
rng::AbstractRNG=default_rng(),
)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:AbstractPopMember{T,L}}
can_optimize(member.tree, options) || return (member, 0.0)
nconst = count_constants_for_optimization(member.tree)
nconst == 0 && return (member, 0.0)
@@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex)

function _optimize_constants(
dataset, member::P, options, algorithm, optimizer_options, rng
)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}}
)::Tuple{P,Float64} where {T,L,P<:AbstractPopMember{T,L}}
tree = member.tree
x0, refs = get_scalar_constants(tree)
@assert count_constants_for_optimization(tree) == length(x0)
@@ -76,7 +76,7 @@ function _optimize_constants(
end
function _optimize_constants_inner(
f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng
)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}}
)::Tuple{P,Float64} where {F,G,T,L,P<:AbstractPopMember{T,L}}
obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing
f
else
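The substantive change in this file widens several signatures from the concrete `PopMember` to `AbstractPopMember`, so the optimizer also accepts user-defined member types. The dispatch pattern, sketched with hypothetical types:

```julia
abstract type AbstractMember{T} end

struct BasicMember{T} <: AbstractMember{T}
    value::T
end

struct FancyMember{T} <: AbstractMember{T}
    value::T
    birth::Int
end

# One method constrained on the abstract type serves every concrete member.
payload(m::M) where {T,M<:AbstractMember{T}} = m.value

payload(BasicMember(1.5))     # 1.5
payload(FancyMember(2.5, 0))  # 2.5
```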
22 changes: 12 additions & 10 deletions src/ExpressionBuilder.jl
@@ -11,7 +11,7 @@ using DynamicExpressions:
using ..CoreModule: AbstractOptions, Dataset
using ..HallOfFameModule: HallOfFame
using ..PopulationModule: Population
using ..PopMemberModule: PopMember
using ..PopMemberModule: PopMember, AbstractPopMember

import DynamicExpressions: get_operators
import ..CoreModule: create_expression
@@ -134,8 +134,10 @@ end
)
end
function embed_metadata(
vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L,H<:Union{HallOfFame,Population,PopMember}}
vec::Vector{<:Union{HallOfFame,Population,AbstractPopMember}},
options::AbstractOptions,
dataset::Dataset{T,L},
) where {T,L}
return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec)
end
end
@@ -153,8 +155,8 @@ function strip_metadata(
return with_metadata(ex; init_params(options, dataset, ex, Val(false))...)
end
function strip_metadata(
member::PopMember, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
member::PM, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L,PM<:PopMember{T,L}}
return PopMember(
strip_metadata(member.tree, options, dataset),
member.cost,
@@ -165,14 +167,14 @@
deterministic=options.deterministic,
)
end
function strip_metadata(
pop::Population, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
@unstable function strip_metadata(
pop::P, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L,P<:Population{T,L}}
return Population(map(member -> strip_metadata(member, options, dataset), pop.members))
end
function strip_metadata(
hof::HallOfFame, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
hof::H, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L,N,PM<:AbstractPopMember,H<:HallOfFame{T,L,N,PM}}
return HallOfFame(
map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists
)
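In `embed_metadata`, the union moved into the container parameter as `Vector{<:Union{...}}`. Because Julia's parametric types are invariant, the `<:` bound is what lets one method accept vectors of any concrete member type. A minimal illustration:

```julia
struct Foo end
struct Bar end

# Invariant: matches only a Vector whose eltype is exactly Union{Foo,Bar}.
f(v::Vector{Union{Foo,Bar}}) = "exact eltype"

# Covariant bound: also matches Vector{Foo} and Vector{Bar}.
g(v::Vector{<:Union{Foo,Bar}}) = "any compatible eltype"

g([Foo(), Foo()])                # matches g; f(...) here would be a MethodError
g(Union{Foo,Bar}[Foo(), Bar()])  # also matches
```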
13 changes: 7 additions & 6 deletions src/HallOfFame.jl
@@ -6,7 +6,7 @@ using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer
using ..CoreModule:
AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value
using ..ComplexityModule: compute_complexity
using ..PopMemberModule: PopMember
using ..PopMemberModule: AbstractPopMember, PopMember
using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING
using Printf: @sprintf

@@ -23,8 +23,10 @@ have been set, you can run `.members[exists]`.
These are ordered by complexity, with `.members[1]` the member with complexity 1.
- `exists::Array{Bool,1}`: Whether the member at the given complexity has been set.
"""
struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}}
members::Array{PopMember{T,L,N},1}
struct HallOfFame{
T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N}
}
members::Array{PM,1}
exists::Array{Bool,1} #Whether it has been set
end
function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N}
@@ -69,7 +71,7 @@ function HallOfFame(
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
base_tree = create_expression(init_value(T), options, dataset)

return HallOfFame{T,L,typeof(base_tree)}(
return HallOfFame{T,L,typeof(base_tree),PopMember{T,L,typeof(base_tree)}}(
[
PopMember(
copy(base_tree),
@@ -95,9 +97,8 @@ end
"""
function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N}
# TODO - remove dataset from args.
P = PopMember{T,L,N}
# Dominating pareto curve - must be better than all simpler equations
dominating = P[]
dominating = similar(hallOfFame.members, 0)
for size in eachindex(hallOfFame.members)
if !hallOfFame.exists[size]
continue
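`calculate_pareto_frontier` now seeds the frontier with `similar(hallOfFame.members, 0)` rather than a hard-coded `PopMember{T,L,N}[]`, so the element type tracks whatever member type the hall of fame stores. The idiom in isolation:

```julia
members = [1.5, 2.5, 3.5]         # stand-in for hof.members

dominating = similar(members, 0)  # empty vector with the same eltype
@assert isempty(dominating)
@assert eltype(dominating) == eltype(members)

push!(dominating, members[1])     # grows like any ordinary Vector
```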
6 changes: 3 additions & 3 deletions src/MLJInterface.jl
@@ -595,9 +595,9 @@ end
function get_equation_strings_for(
::AbstractSingletargetSRRegressor, trees, options, variable_names
)
return (t -> string_tree(t, options; variable_names=variable_names, pretty=false)).(
trees
)
return (
t -> string_tree(t, options; variable_names=variable_names, pretty=false)
).(trees)
end
function get_equation_strings_for(
::AbstractMultitargetSRRegressor, trees, options, variable_names
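The change here only re-wraps a broadcast of an anonymous function; parenthesizing the lambda and applying `.` is equivalent to `map` over the collection:

```julia
render(t) = "y = " * t
trees = ["x1 + 1", "sin(x1)"]       # stand-ins for expression objects

strings = (t -> render(t)).(trees)  # broadcast the anonymous function
@assert strings == map(render, trees)
```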
4 changes: 2 additions & 2 deletions src/Migration.jl
@@ -2,7 +2,7 @@ module MigrationModule

using ..CoreModule: AbstractOptions
using ..PopulationModule: Population
using ..PopMemberModule: PopMember, reset_birth!
using ..PopMemberModule: AbstractPopMember, reset_birth!
using ..UtilsModule: poisson_sample

"""
Expand All @@ -14,7 +14,7 @@ to do so. The original migrant population is not modified. Pass with, e.g.,
"""
function migrate!(
migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat
) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}}
) where {T,L,N,PM<:AbstractPopMember{T,L,N},P<:Population{T,L,N}}
base_pop = migration.second
population_size = length(base_pop.members)
mean_number_replaced = population_size * frac
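`migrate!` dispatches on a `Pair` built with `=>`, which reads naturally as "migrants into population". A hedged sketch of the same Pair-based signature, with a hypothetical `transfer!` over plain vectors:

```julia
# Hypothetical analogue of migrate!'s `Pair{Vector{PM},P}` signature.
function transfer!(migration::Pair{Vector{Int},Vector{Int}}; frac::AbstractFloat)
    migrants, dst = migration    # a Pair destructures like a 2-tuple
    n = round(Int, length(migrants) * frac)
    append!(dst, migrants[1:n])  # the source vector is left untouched
    return dst
end

transfer!([1, 2, 3, 4] => Int[]; frac=0.5)  # Int[1, 2]
```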