Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 34 additions & 36 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ function create_utils_benchmark()
suite["best_of_sample"] = @benchmarkable(
best_of_sample(pop, rss, $options),
setup = (
nfeatures=1;
dataset=Dataset(randn(nfeatures, 32), randn(32));
pop=Population(dataset; npop=100, nlength=20, options=($options), nfeatures);
rss=RunningSearchStatistics(; options=($options))
nfeatures = 1;
dataset = Dataset(randn(nfeatures, 32), randn(32));
pop = Population(dataset; npop=100, nlength=20, options=($options), nfeatures);
rss = RunningSearchStatistics(; options=($options))
)
)

Expand All @@ -110,9 +110,9 @@ function create_utils_benchmark()
end
end,
setup = (
nfeatures=1;
dataset=Dataset(randn(nfeatures, 32), randn(32));
mutation_weights=MutationWeights(;
nfeatures = 1;
dataset = Dataset(randn(nfeatures, 32), randn(32));
mutation_weights = MutationWeights(;
mutate_constant=1.0,
mutate_operator=1.0,
swap_operands=1.0,
Expand All @@ -125,23 +125,21 @@ function create_utils_benchmark()
form_connection=0.0,
break_connection=0.0,
);
options=Options(;
unary_operators=[sin, cos],
binary_operators=[+, -, *, /],
mutation_weights,
options = Options(;
unary_operators=[sin, cos], binary_operators=[+, -, *, /], mutation_weights
);
recorder=RecordType();
temperature=1.0;
curmaxsize=20;
rss=RunningSearchStatistics(; options);
trees=[
recorder = RecordType();
temperature = 1.0;
curmaxsize = 20;
rss = RunningSearchStatistics(; options);
trees = [
gen_random_tree_fixed_size(15, options, nfeatures, Float64) for _ in 1:100
];
expressions=[
expressions = [
Expression(tree; operators=options.operators, variable_names=["x1"]) for
tree in trees
];
members=[
members = [
PopMember(dataset, expression, options; deterministic=false) for
expression in expressions
]
Expand All @@ -155,14 +153,14 @@ function create_utils_benchmark()
end,
seconds = 20,
setup = (
nfeatures=1;
T=Float64;
dataset=Dataset(randn(nfeatures, 512), randn(512));
ntrees=($ntrees);
trees=[
nfeatures = 1;
T = Float64;
dataset = Dataset(randn(nfeatures, 512), randn(512));
ntrees = ($ntrees);
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:ntrees
];
members=[
members = [
PopMember(dataset, tree, $options; deterministic=false) for tree in trees
]
)
Expand All @@ -181,9 +179,9 @@ function create_utils_benchmark()
compute_complexity(tree, $options)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for
i in 1:($ntrees)
]
Expand All @@ -199,9 +197,9 @@ function create_utils_benchmark()
SymbolicRegression.MutationFunctionsModule.randomly_rotate_tree!(tree)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for
i in 1:($ntrees)
]
Expand All @@ -216,9 +214,9 @@ function create_utils_benchmark()
)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees)
]
)
Expand All @@ -242,9 +240,9 @@ function create_utils_benchmark()
check_constraints(tree, $options, $options.maxsize)
end,
setup = (
T=Float64;
nfeatures=3;
trees=[
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees)
]
)
Expand Down
6 changes: 3 additions & 3 deletions src/CheckConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ function check_constraints(
return true
end

check_constraints(ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions)::Bool = check_constraints(
ex, options, options.maxsize
)
check_constraints(
ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions
)::Bool = check_constraints(ex, options, options.maxsize)

end
5 changes: 2 additions & 3 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,8 @@ function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
if all(_is_valid, x)
return _apply_operator(op, x...)
else
example_vector = something(
map(xi -> xi isa ValidVector ? xi : nothing, x)...
)::ValidVector
example_vector =
something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector
expected_return_type = Base.promote_op(
_apply_operator, typeof(op), map(typeof, x)...
)
Expand Down
19 changes: 12 additions & 7 deletions src/Configure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,12 @@ function activate_env_on_workers(
)
verbosity > 0 && @info "Activating environment on workers."
@everywhere procs begin
Base.MainInclude.eval(quote
using Pkg
Pkg.activate($$project_path)
end)
Base.MainInclude.eval(
quote
using Pkg
Pkg.activate($$project_path)
end,
)
end
end

Expand Down Expand Up @@ -289,9 +291,12 @@ function import_module_on_workers(
all_extensions = vcat(relevant_extensions, @something(worker_imports, Symbol[]))

for ext in all_extensions
push!(expr.args, quote
using $ext: $ext
end)
push!(
expr.args,
quote
using $ext: $ext
end,
)
end

verbosity > 0 && if isempty(relevant_extensions)
Expand Down
6 changes: 3 additions & 3 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,9 @@ end
function get_equation_strings_for(
::AbstractSingletargetSRRegressor, trees, options, variable_names
)
return (t -> string_tree(t, options; variable_names=variable_names, pretty=false)).(
trees
)
return (
t -> string_tree(t, options; variable_names=variable_names, pretty=false)
).(trees)
end
function get_equation_strings_for(
::AbstractMultitargetSRRegressor, trees, options, variable_names
Expand Down
5 changes: 2 additions & 3 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ end
break
end
end
found_degree == 0 &&
error("Operator $(op) is not in the operator set.")
found_degree == 0 && error("Operator $(op) is not in the operator set.")
(found_degree, found_idx)
end,
new_max_nesting_dict = [
Expand All @@ -167,7 +166,7 @@ end
end
end
found_degree == 0 &&
error("Operator $(nested_op) is not in the operator set.")
error("Operator $(nested_op) is not in the operator set.")
(found_degree, found_idx)
end
(nested_degree, nested_idx, max_nesting)
Expand Down
18 changes: 10 additions & 8 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,18 @@ to avoid spam when worker processes exit normally.
macro filtered_async(expr)
return esc(
quote
$(Base).errormonitor(@async begin
try
$expr
catch ex
if !(ex isa $(Distributed).ProcessExitedException)
rethrow(ex)
$(Base).errormonitor(
@async begin
try
$expr
catch ex
if !(ex isa $(Distributed).ProcessExitedException)
rethrow(ex)
end
end
end
end)
end
)
end,
)
end

Expand Down
10 changes: 6 additions & 4 deletions test/manual_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ procs = addprocs(2)
using Test, Pkg
project_path = splitdir(Pkg.project().path)[1]
@everywhere procs begin
Base.MainInclude.eval(quote
using Pkg
Pkg.activate($$project_path)
end)
Base.MainInclude.eval(
quote
using Pkg
Pkg.activate($$project_path)
end,
)
end
@everywhere using SymbolicRegression
@everywhere _inv(x::Float32)::Float32 = 1.0f0 / x
Expand Down
6 changes: 3 additions & 3 deletions test/test_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ end
!has_cos(member.tree) || any(
t ->
t.degree == 1 &&
options.operators.unaops[t.op] == cos &&
Node(Float64; feature=1) in t &&
compute_complexity(t, options) > 1,
options.operators.unaops[t.op] == cos &&
Node(Float64; feature=1) in t &&
compute_complexity(t, options) > 1,
get_tree(member.tree),
) for member in dominating
]
Expand Down
4 changes: 2 additions & 2 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ function simple_bottomk(x, k)
end

array_options = [
(n=n, seed=seed, T=T) for
n in (1, 5, 20, 50, 100, 1000), seed in 1:10, T in (Float32, Float64, Int)
(n=n, seed=seed, T=T) for n in (1, 5, 20, 50, 100, 1000), seed in 1:10,
T in (Float32, Float64, Int)
]

@testset "argmin_fast" begin
Expand Down