Skip to content

Commit 26665da

Browse files
committed
fix: parse_guesses for custom AbstractPopMember
1 parent d0b13dd commit 26665da

File tree

5 files changed

+37
-13
lines changed

5 files changed

+37
-13
lines changed

src/ExpressionBuilder.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ end
118118
options;
119119
complexity=compute_complexity(member, options),
120120
parent_ref=member.ref,
121-
deterministic=options.deterministic,
122121
)
123122
end
124123
function embed_metadata(

src/MLJInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ using ..CoreModule:
3939
ExpressionSpec,
4040
get_expression_type,
4141
check_warm_start_compatibility
42-
using ..CoreModule.OptionsModule:
43-
DEFAULT_OPTIONS, OPTION_DESCRIPTIONS, default_popmember_type
42+
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
43+
using ..PopMemberModule: default_popmember_type
4444
using ..ComplexityModule: compute_complexity
4545
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
4646
using ..UtilsModule: subscriptify, @ignore

src/PopMember.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ end
201201

202202
"""
203203
create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options;
204-
complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where {T,L,P<:PopMember{T,L}}
204+
complexity::Union{Int,Nothing}=nothing, parent_ref) where {T,L,P<:PopMember{T,L}}
205205
206206
Create a new PopMember with a potentially different expression type.
207207
Used by embed_metadata where the expression gains metadata.
@@ -214,7 +214,6 @@ function create_child(
214214
options;
215215
complexity::Union{Int,Nothing}=nothing,
216216
parent_ref,
217-
kwargs...,
218217
) where {T,L,P<:PopMember{T,L}}
219218
actual_complexity = @something complexity compute_complexity(tree, options)
220219
return PopMember(
@@ -230,7 +229,7 @@ end
230229

231230
"""
232231
create_child(parents::Tuple{P,P}, tree, cost, loss, options;
233-
complexity::Union{Int,Nothing}=nothing, parent_ref, kwargs...) where P<:AbstractPopMember
232+
complexity::Union{Int,Nothing}=nothing, parent_ref) where P<:AbstractPopMember
234233
235234
Create a new PopMember from two parents (crossover case).
236235
Custom types should override to blend their additional fields.
@@ -243,7 +242,6 @@ function create_child(
243242
options;
244243
complexity::Union{Int,Nothing}=nothing,
245244
parent_ref,
246-
kwargs...,
247245
) where {T,L,P<:PopMember{T,L}}
248246
actual_complexity = @something complexity compute_complexity(tree, options)
249247
return PopMember(
@@ -263,4 +261,10 @@ function popmember_type end
263261
@unstable default_popmember_type() = PopMember
264262
@unstable constructorof(::Type{<:PopMember}) = PopMember
265263

264+
@inline function with_expression_type(
265+
::Type{<:PopMember{T,L}}, ::Type{N}
266+
) where {T,L,N<:AbstractExpression{T}}
267+
return PopMember{T,L,N}
268+
end
269+
266270
end

src/SearchUtils.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@ using DispatchDoctor: @unstable
1212
using Logging: AbstractLogger
1313

1414
using DynamicExpressions:
15-
AbstractExpression, string_tree, parse_expression, EvalOptions, with_type_parameters
15+
AbstractExpression,
16+
string_tree,
17+
parse_expression,
18+
EvalOptions,
19+
with_type_parameters,
20+
constructorof
1621
using ..UtilsModule: subscriptify
17-
using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features
22+
using ..CoreModule:
23+
Dataset, AbstractOptions, Options, RecordType, max_features, create_expression
1824
using ..ComplexityModule: compute_complexity
1925
using ..PopulationModule: Population
20-
using ..PopMemberModule: PopMember, AbstractPopMember
26+
using ..PopMemberModule: PopMember, AbstractPopMember, with_expression_type
2127
using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve
2228
using ..ConstantOptimizationModule: optimize_constants
2329
using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen
@@ -800,7 +806,9 @@ function parse_guesses(
800806
dataset = datasets[j]
801807
for g in guess_lists[j]
802808
ex = _parse_guess_expression(T, g, dataset, options)
803-
member = PopMember(dataset, ex, options; deterministic=options.deterministic)
809+
member = constructorof(P)(
810+
dataset, ex, options; deterministic=options.deterministic
811+
)
804812
if options.should_optimize_constants
805813
member, _ = optimize_constants(dataset, member, options)
806814
end
@@ -818,6 +826,21 @@ function parse_guesses(
818826
end
819827
return out
820828
end
829+
830+
# Deal with non-concrete PopMember types
831+
function parse_guesses(
832+
::Type{P},
833+
guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}},
834+
datasets::Vector{D},
835+
options::AbstractOptions,
836+
) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}}
837+
NodeType = with_type_parameters(options.node_type, T)
838+
N = Base.promote_op(create_expression, NodeType, typeof(options), D)
839+
N === Any && error("Failed to infer expression type")
840+
ConcreteP = with_expression_type(P, N)
841+
return parse_guesses(ConcreteP, guesses, datasets, options)
842+
end
843+
821844
function _make_vector_vector(guesses, nout)
822845
if nout == 1
823846
if guesses isa AbstractVector{<:AbstractVector}

test/test_abstract_popmember.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@
101101
options;
102102
complexity::Union{Int,Nothing}=nothing,
103103
parent_ref,
104-
kwargs...,
105104
) where {T,L}
106105
actual_complexity = @something complexity SymbolicRegression.compute_complexity(
107106
tree, options
@@ -126,7 +125,6 @@
126125
options;
127126
complexity::Union{Int,Nothing}=nothing,
128127
parent_ref,
129-
kwargs...,
130128
) where {T,L,N<:AbstractExpression{T}}
131129
actual_complexity = @something complexity SymbolicRegression.compute_complexity(
132130
tree, options

0 commit comments

Comments
 (0)