Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/SymbolicRegressionSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module SymbolicRegressionSymbolicUtilsExt

using SymbolicUtils: Symbolic
using SymbolicRegression: AbstractExpressionNode, AbstractExpression, Node, Options
using SymbolicRegression: AbstractExpressionNode, AbstractExpression, Options
using SymbolicRegression.MLJInterfaceModule: AbstractSymbolicRegressor, get_options
using DynamicExpressions: get_tree, get_operators

Expand Down
3 changes: 3 additions & 0 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ using DynamicExpressions.InterfacesModule:
ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments
using DynamicExpressions.ValueInterfaceModule: is_valid_array

using ..UtilsModule: @intentional_import
using ..ConstantOptimizationModule: ConstantOptimizationModule as CO
using ..CoreModule: get_safe_op

@intentional_import Interfaces

abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end

"""
Expand Down
20 changes: 4 additions & 16 deletions src/Configure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,23 +251,12 @@ function activate_env_on_workers(
end

function import_module_on_workers(
procs,
filename::String,
@nospecialize(worker_imports::Union{Vector{Symbol},Nothing}),
verbosity,
procs, @nospecialize(worker_imports::Union{Vector{Symbol},Nothing}), verbosity
)
loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
expr = if included_as_local
quote
include($filename)
using .SymbolicRegression
end
else
quote
using SymbolicRegression
end
expr = quote
using SymbolicRegression
end

# Need to import any extension code, if loaded on head node
Expand Down Expand Up @@ -367,7 +356,6 @@ function configure_workers(;
options::AbstractOptions,
@nospecialize(worker_imports::Union{Vector{Symbol},Nothing}),
project_path,
file,
exeflags::Cmd,
verbosity,
example_dataset::Dataset,
Expand All @@ -382,7 +370,7 @@ function configure_workers(;
end

if we_created_procs
import_module_on_workers(procs, file, worker_imports, verbosity)
import_module_on_workers(procs, worker_imports, verbosity)
end

move_functions_to_workers(procs, options, example_dataset, verbosity)
Expand Down
20 changes: 20 additions & 0 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,25 @@ using .ExpressionSpecModule:
get_expression_options,
get_node_type
using .InterfaceDataTypesModule: init_value, sample_value, mutate_value
using .UtilsModule: @intentional_import

# These are imported from submodules for use by parent module
@intentional_import RecordType, DATA_TYPE, LOSS_TYPE
@intentional_import Dataset, BasicDataset, SubDataset, is_weighted, has_units
@intentional_import max_features, batch, get_indices, get_full_dataset, dataset_fraction
@intentional_import AbstractMutationWeights, MutationWeights, sample_mutation
@intentional_import AbstractOptions, Options, ComplexityMapping, WarmStartIncompatibleError
@intentional_import check_warm_start_compatibility, get_safe_op
@intentional_import plus, sub, mult, square, cube, pow, safe_pow, div
@intentional_import safe_log, safe_log2, safe_log10, safe_log1p, safe_asin
@intentional_import safe_acos, safe_atan, safe_acosh, safe_atanh_clip
@intentional_import safe_sqrt, safe_cbrt, neg, greater, cond, relu
@intentional_import greater_equal, less, less_equal
@intentional_import logical_or, logical_and, gamma, safe_atanh, safe_csch
@intentional_import operator_specialization, specialized_options
@intentional_import erf, erfc, atanh_clip
@intentional_import AbstractExpressionSpec, ExpressionSpec, get_expression_type
@intentional_import get_expression_options, get_node_type
@intentional_import init_value, sample_value, mutate_value

end
2 changes: 1 addition & 1 deletion src/DimensionalAnalysis.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DimensionalAnalysisModule

using DynamicExpressions:
AbstractExpression, AbstractExpressionNode, get_tree, get_child, tree_mapreduce
AbstractExpression, AbstractExpressionNode, get_tree, tree_mapreduce
using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, constructorof

using ..CoreModule: AbstractOptions, Dataset
Expand Down
2 changes: 1 addition & 1 deletion src/ExpressionSpec.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ExpressionSpecModule

using DynamicExpressions: AbstractExpression, Expression, AbstractExpressionNode, Node
using DynamicExpressions: Expression, Node

abstract type AbstractExpressionSpec end

Expand Down
2 changes: 1 addition & 1 deletion src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using DynamicExpressions:
GraphNode,
EvalOptions
using DynamicQuantities: dimension, ustrip
using ..CoreModule: AbstractOptions, Dataset
using ..CoreModule: AbstractOptions
using ..CoreModule.OptionsModule: inverse_opmap
using ..UtilsModule: subscriptify

Expand Down
9 changes: 8 additions & 1 deletion src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,19 @@ using ..CoreModule:
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
using ..UtilsModule: subscriptify, @ignore
using ..UtilsModule: subscriptify, @ignore, @intentional_import
using ..LoggingModule: AbstractSRLogger
using ..TemplateExpressionModule: TemplateExpression

import ..equation_search

# These imports are needed for MLJ generated code
@intentional_import AbstractADType, AbstractExpressionNode, AbstractExpressionSpec
@intentional_import AbstractLogger, AbstractMutationWeights, AbstractOperatorEnum
@intentional_import ComplexityMapping, Dataset, Expression, MutationWeights
@intentional_import Node, SupervisedLoss, TemplateExpression, compute_complexity
@intentional_import default_node_type, get_tree

abstract type AbstractSymbolicRegressor <: MMI.Deterministic end

abstract type AbstractSingletargetSRRegressor <: AbstractSymbolicRegressor end
Expand Down
42 changes: 21 additions & 21 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ end
tree = rtree[]

if !successful_mutation
@recorder begin
@recorder options begin
tmp_recorder["result"] = "reject"
tmp_recorder["reason"] = "failed_constraint_check"
end
Expand All @@ -271,7 +271,7 @@ end
num_evals += dataset_fraction(dataset)

if isnan(after_cost)
@recorder begin
@recorder options begin
tmp_recorder["result"] = "reject"
tmp_recorder["reason"] = "nan_loss"
end
Expand Down Expand Up @@ -315,7 +315,7 @@ end
end

if probChange < rand()
@recorder begin
@recorder options begin
tmp_recorder["result"] = "reject"
tmp_recorder["reason"] = "annealing_or_frequency"
end
Expand All @@ -334,7 +334,7 @@ end
num_evals,
)
else
@recorder begin
@recorder options begin
tmp_recorder["result"] = "accept"
tmp_recorder["reason"] = "pass"
end
Expand Down Expand Up @@ -429,7 +429,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = mutate_constant(tree, temperature, options)
@recorder recorder["type"] = "mutate_constant"
@recorder options recorder["type"] = "mutate_constant"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -443,7 +443,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = mutate_operator(tree, options)
@recorder recorder["type"] = "mutate_operator"
@recorder options recorder["type"] = "mutate_operator"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -458,7 +458,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = mutate_feature(tree, nfeatures)
@recorder recorder["type"] = "mutate_feature"
@recorder options recorder["type"] = "mutate_feature"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -472,7 +472,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = swap_operands(tree)
@recorder recorder["type"] = "swap_operands"
@recorder options recorder["type"] = "swap_operands"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -488,10 +488,10 @@ function mutate!(
) where {N<:AbstractExpression,P<:PopMember}
if rand() < 0.5
tree = append_random_op(tree, options, nfeatures)
@recorder recorder["type"] = "add_node:append"
@recorder options recorder["type"] = "add_node:append"
else
tree = prepend_random_op(tree, options, nfeatures)
@recorder recorder["type"] = "add_node:prepend"
@recorder options recorder["type"] = "add_node:prepend"
end
return MutationResult{N,P}(; tree=tree)
end
Expand All @@ -507,7 +507,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = insert_random_op(tree, options, nfeatures)
@recorder recorder["type"] = "insert_node"
@recorder options recorder["type"] = "insert_node"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -521,7 +521,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = delete_random_op!(tree)
@recorder recorder["type"] = "delete_node"
@recorder options recorder["type"] = "delete_node"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -535,7 +535,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = form_random_connection!(tree)
@recorder recorder["type"] = "form_connection"
@recorder options recorder["type"] = "form_connection"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -549,7 +549,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = break_random_connection!(tree)
@recorder recorder["type"] = "break_connection"
@recorder options recorder["type"] = "break_connection"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -563,7 +563,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
tree = randomly_rotate_tree!(tree)
@recorder recorder["type"] = "rotate_tree"
@recorder options recorder["type"] = "rotate_tree"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -581,7 +581,7 @@ function mutate!(
@assert options.should_simplify
simplify_tree!(tree, options.operators)
tree = combine_operators(tree, options.operators)
@recorder recorder["type"] = "simplify"
@recorder options recorder["type"] = "simplify"
return MutationResult{N,P}(;
member=PopMember(
tree,
Expand All @@ -607,7 +607,7 @@ function mutate!(
kws...,
) where {T,N<:AbstractExpression{T},P<:PopMember}
tree = randomize_tree(tree, curmaxsize, options, nfeatures)
@recorder recorder["type"] = "randomize"
@recorder options recorder["type"] = "randomize"
return MutationResult{N,P}(; tree=tree)
end

Expand All @@ -622,7 +622,7 @@ function mutate!(
kws...,
) where {N<:AbstractExpression,P<:PopMember}
cur_member, new_num_evals = optimize_constants(dataset, member, options)
@recorder recorder["type"] = "optimize"
@recorder options recorder["type"] = "optimize"
return MutationResult{N,P}(;
member=cur_member, num_evals=new_num_evals, return_immediately=true
)
Expand All @@ -638,7 +638,7 @@ function mutate!(
parent_ref,
kws...,
) where {N<:AbstractExpression,P<:PopMember}
@recorder begin
@recorder options begin
recorder["type"] = "identity"
recorder["result"] = "accept"
recorder["reason"] = "identity"
Expand Down Expand Up @@ -686,7 +686,7 @@ function crossover_generation(
break
end
if num_tries > max_tries
@recorder begin
@recorder options begin
recorder["result"] = "reject"
recorder["reason"] = "failed_constraint_check"
end
Expand Down Expand Up @@ -723,7 +723,7 @@ function crossover_generation(
deterministic=options.deterministic,
)::P

@recorder begin
@recorder options begin
recorder["result"] = "accept"
recorder["reason"] = "pass"
end
Expand Down
4 changes: 1 addition & 3 deletions src/MutationFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ using DynamicExpressions:
with_contents,
constructorof,
set_node!,
count_nodes,
has_constants,
has_operators,
get_child,
set_child!,
max_degree
set_child!
using ..CoreModule: AbstractOptions, DATA_TYPE, init_value, sample_value

import ..CoreModule: mutate_value
Expand Down
6 changes: 4 additions & 2 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ using DynamicQuantities: UnionAbstractQuantity
using SpecialFunctions: erf, erfc
using Base: @deprecate
using DynamicDiff: ForwardDiff
using ..ProgramConstantsModule: DATA_TYPE
using ...UtilsModule: @ignore
using ...UtilsModule: @ignore, @intentional_import

@intentional_import erf, erfc

#TODO - actually add these operators to the module!

# TODO: Should this be limited to AbstractFloat instead?
Expand Down
1 change: 0 additions & 1 deletion src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Optim: Optim
using DynamicExpressions:
OperatorEnum,
AbstractOperatorEnum,
Expression,
default_node_type,
AbstractExpression,
AbstractExpressionNode
Expand Down
1 change: 0 additions & 1 deletion src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ using Random: default_rng, AbstractRNG
using ..CoreModule:
AbstractOptions,
Dataset,
SubDataset,
DATA_TYPE,
AbstractMutationWeights,
AbstractExpressionSpec,
Expand Down
2 changes: 1 addition & 1 deletion src/Population.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module PopulationModule
using StatsBase: StatsBase
using DispatchDoctor: @unstable
using DynamicExpressions: AbstractExpression, string_tree
using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
using ..ComplexityModule: compute_complexity
using ..LossFunctionsModule: eval_cost, update_baseline_loss!
using ..AdaptiveParsimonyModule: RunningSearchStatistics
Expand Down
8 changes: 4 additions & 4 deletions src/Recorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ module RecorderModule

using ..CoreModule: RecordType

"Assumes that `options` holds the user options::AbstractOptions"
macro recorder(ex)
quote
if $(esc(:options)).use_recorder
"Conditionally execute code based on options.use_recorder"
macro recorder(options, ex)
return quote
if $(esc(options)).use_recorder
$(esc(ex))
end
end
Expand Down
Loading
Loading