Skip to content

Commit 4d3e1cf

Browse files
committed
refactor: infer_popmember_type
1 parent 2dc0645 commit 4d3e1cf

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

src/PopMember.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module PopMemberModule
22

33
using DispatchDoctor: @unstable
44
using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree
5-
import DynamicExpressions: constructorof
5+
import DynamicExpressions: constructorof, with_type_parameters
66
using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression
77
import ..CoreModule.OptionsModule: default_popmember_type
88
import ..ComplexityModule: compute_complexity
@@ -267,4 +267,14 @@ function popmember_type end
267267
return PopMember{T,L,N}
268268
end
269269

270+
@inline function with_type_parameters(
271+
::Type{<:PopMember}, ::Type{T}, ::Type{L}, ::Type{N}
272+
) where {T,L,N}
273+
return PopMember{T,L,N}
274+
end
275+
276+
@inline function expression_type(::Type{<:AbstractPopMember{<:Any,<:Any,N}}) where {N}
277+
return N
278+
end
279+
270280
end

src/SearchUtils.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using ..CoreModule:
2323
Dataset, AbstractOptions, Options, RecordType, max_features, create_expression
2424
using ..ComplexityModule: compute_complexity
2525
using ..PopulationModule: Population
26-
using ..PopMemberModule: PopMember, AbstractPopMember, with_expression_type
26+
using ..PopMemberModule: PopMember, AbstractPopMember
2727
using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve
2828
using ..ConstantOptimizationModule: optimize_constants
2929
using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen
@@ -34,6 +34,15 @@ using ..CheckConstraintsModule: check_constraints
3434

3535
function logging_callback! end
3636

37+
@unstable @inline function infer_popmember_type(
38+
::Type{T}, ::Type{L}, ::Type{D}, options
39+
) where {T,L,D<:Dataset}
40+
NodeType = with_type_parameters(options.node_type, T)
41+
N = Base.promote_op(create_expression, NodeType, typeof(options), D)
42+
N in (Any, Union{}) && error("Failed to infer expression type")
43+
return with_type_parameters(options.popmember_type, T, L, N)
44+
end
45+
3746
"""
3847
@filtered_async expr
3948
@@ -793,26 +802,13 @@ end
793802

794803
"""Parse user-provided guess expressions and convert them into optimized
795804
`PopMember` objects for each output dataset."""
796-
@unstable function parse_guesses(
797-
::Type{P},
798-
guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}},
799-
datasets::Vector{D},
800-
options::AbstractOptions,
801-
) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}}
802-
return _parse_guesses_impl(P, guesses, datasets, options)
803-
end
804-
805-
# Deal with non-concrete PopMember types
806805
@unstable function parse_guesses(
807806
::Type{P},
808807
guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}},
809808
datasets::Vector{D},
810809
options::AbstractOptions,
811810
) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}}
812-
NodeType = with_type_parameters(options.node_type, T)
813-
N = Base.promote_op(create_expression, NodeType, typeof(options), D)
814-
N in (Any, Union{}) && error("Failed to infer expression type")
815-
ConcreteP = with_expression_type(P, N)
811+
ConcreteP = infer_popmember_type(T, L, D, options)
816812
return _parse_guesses_impl(ConcreteP, guesses, datasets, options)
817813
end
818814

src/SymbolicRegression.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ using .MutationFunctionsModule:
297297
using .InterfaceDynamicExpressionsModule:
298298
@extend_operators, require_copy_to_workers, make_example_inputs
299299
using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func
300-
using .PopMemberModule: AbstractPopMember, PopMember, reset_birth!, popmember_type
300+
using .PopMemberModule:
301+
AbstractPopMember, PopMember, reset_birth!, popmember_type, expression_type
301302
using .CoreModule.UtilsModule: get_birth_order
302303
using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample
303304
using .HallOfFameModule:
@@ -339,7 +340,8 @@ using .SearchUtilsModule:
339340
get_cur_maxsize,
340341
update_hall_of_fame!,
341342
parse_guesses,
342-
logging_callback!
343+
logging_callback!,
344+
infer_popmember_type
343345
using .LoggingModule: AbstractSRLogger, SRLogger, get_logger
344346
using .TemplateExpressionModule:
345347
TemplateExpression, TemplateStructure, TemplateExpressionSpec, ParamVector, has_params
@@ -631,20 +633,8 @@ end
631633
@recorder record["options"] = "$(options)"
632634

633635
nout = length(datasets)
634-
example_dataset = first(datasets)
635-
example_ex = create_expression(init_value(T), options, example_dataset)
636-
NT = typeof(example_ex)
637-
# Create a prototype member to get the concrete type
638-
prototype_member = options.popmember_type(
639-
copy(example_ex),
640-
L(0),
641-
L(Inf),
642-
options,
643-
1; # complexity
644-
parent=-1,
645-
deterministic=options.deterministic,
646-
)
647-
PMType = typeof(prototype_member)
636+
PMType = infer_popmember_type(T, L, D, options)
637+
NT = expression_type(PMType)
648638
PopType = Population{T,L,NT,PMType}
649639
HallOfFameType = HallOfFame{T,L,NT,PMType}
650640
WorkerOutputType = get_worker_output_type(

test/test_abstract_popmember.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@
7979

8080
@unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember
8181

82+
# Define with_type_parameters for CustomPopMember
83+
@unstable function DynamicExpressions.with_type_parameters(
84+
::Type{<:CustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N}
85+
) where {T,L,N}
86+
return CustomPopMember{T,L,N}
87+
end
88+
8289
# Define copy for CustomPopMember
8390
function Base.copy(p::CustomPopMember)
8491
return CustomPopMember(

0 commit comments

Comments
 (0)