diff --git a/Project.toml b/Project.toml index f5a02e829..0e1cc78f5 100644 --- a/Project.toml +++ b/Project.toml @@ -36,12 +36,14 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [extensions] SymbolicRegressionEnzymeExt = "Enzyme" SymbolicRegressionJSON3Ext = "JSON3" SymbolicRegressionMooncakeExt = "Mooncake" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" +SymbolicRegressionTablesExt = "Tables" [compat] ADTypes = "^1.4.0" @@ -74,4 +76,5 @@ StatsBase = "0.33, 0.34" StyledStrings = "1" SymbolicUtils = "0.19, ^1.0.5, 2, 3" TOML = "<0.0.1, 1" +Tables = "1" julia = "1.10" diff --git a/ext/SymbolicRegressionTablesExt.jl b/ext/SymbolicRegressionTablesExt.jl new file mode 100644 index 000000000..0e07aed45 --- /dev/null +++ b/ext/SymbolicRegressionTablesExt.jl @@ -0,0 +1,46 @@ +module SymbolicRegressionTablesExt + +using Tables: Tables +import SymbolicRegression.HallOfFameModule: HOFRows, member_to_row + +# Make HOFRows compatible with the Tables.jl interface +# HOFRows is already iterable via Base.iterate, so we just need to declare compatibility +Tables.istable(::Type{<:HOFRows}) = true +Tables.rowaccess(::Type{<:HOFRows}) = true +Tables.rows(view::HOFRows) = view # Return itself since it's already iterable + +# Provide schema information for better Tables.jl integration +function Tables.schema(rows::HOFRows) + if isempty(rows.members) + # Empty table - can't infer schema + return nothing + end + + # Get column names from either column specs or first row + if rows.columns !== nothing + # Use explicit column specs + names = Tuple(col.name for col in rows.columns) + # We can't reliably infer types without evaluating, so return nothing for types + return Tables.Schema(names, nothing) + else + # Infer from first row + first_row = member_to_row( + rows.members[1], rows.dataset, rows.options; pretty=rows.pretty + ) + if rows.include_score + # Will add score in iteration + names = (keys(first_row)..., :score) + else + names = keys(first_row) + end + # Get types from first row + types = if rows.include_score + (typeof.(values(first_row))..., Float64) # Assume Float64 for score + else + typeof.(values(first_row)) + end + return Tables.Schema(names, types) + end +end + +end diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 585a6ef78..7fa7798d8 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -17,7 +17,7 @@ using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction using ..UtilsModule: get_birth_order, PerTaskCache, stable_get! 
using ..LossFunctionsModule: eval_loss, loss_to_cost -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember function can_optimize(::AbstractExpression{T}, options) where {T} return can_optimize(T, options) @@ -31,7 +31,7 @@ end member::P, options::AbstractOptions; rng::AbstractRNG=default_rng(), -)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,N,P<:AbstractPopMember{T,L,N}} can_optimize(member.tree, options) || return (member, 0.0) nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) @@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {T,L,N,P<:AbstractPopMember{T,L,N}} tree = member.tree x0, refs = get_scalar_constants(tree) @assert count_constants_for_optimization(tree) == length(x0) @@ -76,7 +76,7 @@ function _optimize_constants( end function _optimize_constants_inner( f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng -)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}} +)::Tuple{P,Float64} where {F,G,T,L,N,P<:AbstractPopMember{T,L,N}} obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing f else diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index 4de4e56a2..99de5bdc6 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -11,7 +11,8 @@ using DynamicExpressions: using ..CoreModule: AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember, create_child +using ..ComplexityModule: compute_complexity import DynamicExpressions: get_operators import ..CoreModule: create_expression @@ -107,16 +108,16 @@ end return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) 
end function embed_metadata( - member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L} - return PopMember( + member::PM, options::AbstractOptions, dataset::Dataset{T,L} + ) where {T,L,N,PM<:AbstractPopMember{T,L,N}} + return create_child( + member, embed_metadata(member.tree, options, dataset), member.cost, member.loss, - nothing; - member.ref, - member.parent, - deterministic=options.deterministic, + options; + complexity=compute_complexity(member, options), + parent_ref=member.ref, ) end function embed_metadata( @@ -135,7 +136,7 @@ end end function embed_metadata( vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L} - ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + ) where {T,L,H<:Union{HallOfFame,Population,AbstractPopMember}} return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec) end end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index c90fbe3a4..a657f1d37 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -2,16 +2,18 @@ module HallOfFameModule using StyledStrings: @styled_str using DynamicExpressions: AbstractExpression, string_tree +using DispatchDoctor: @unstable using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression, init_value using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf """ - HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N}} List of the best members seen all time in `.members`, with `.members[c]` being the best member seen at complexity c. Including only the members which actually @@ -19,15 +21,19 @@ have been set, you can run `.members[exists]`. # Fields -- `members::Array{PopMember{T,L,N},1}`: List of the best members seen all time. +- `members::Array{PM,1}`: List of the best members seen all time. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. 
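# Example

A minimal usage sketch (illustrative only; it assumes `options::AbstractOptions` and a
matching `dataset::Dataset` are already constructed, and uses `calculate_pareto_frontier`
from this module):

```julia
hof = HallOfFame(options, dataset)             # one empty slot per complexity 1:options.maxsize
best_per_complexity = hof.members[hof.exists]  # only the slots that have actually been set
dominating = calculate_pareto_frontier(hof)    # members better than every simpler member
```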
""" -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct HallOfFame{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} exists::Array{Bool,1} #Whether it has been set end -function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} +function Base.show( + io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N,PM} +) where {T,L,N,PM} println(io, "HallOfFame{...}:") for i in eachindex(hof.members, hof.exists) s_member, s_exists = if hof.exists[i] @@ -47,8 +53,8 @@ function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where end return nothing end -function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,HOF<:HallOfFame{T,L,N}} - return PopMember{T,L,N} +function Base.eltype(::Union{HOF,Type{HOF}}) where {T,L,N,PM,HOF<:HallOfFame{T,L,N,PM}} + return PM end """ @@ -68,17 +74,36 @@ function HallOfFame( options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} base_tree = create_expression(init_value(T), options, dataset) + PM = options.popmember_type - return HallOfFame{T,L,typeof(base_tree)}( + # Create a prototype member to get the concrete type + prototype = PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + + PMtype = typeof(prototype) + + return HallOfFame{T,L,typeof(base_tree),PMtype}( [ - PopMember( - copy(base_tree), - L(0), - L(Inf), - options; - parent=-1, - deterministic=options.deterministic, - ) for i in 1:(options.maxsize) + if i == 1 + prototype + else + PM( + copy(base_tree), + L(0), + L(Inf), + options, + 1; # complexity + parent=-1, + deterministic=options.deterministic, + ) + end for i in 1:(options.maxsize) ], [false for i in 1:(options.maxsize)], ) @@ -93,11 +118,10 @@ end """ calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,P}) where {T<:DATA_TYPE,L<:LOSS_TYPE} """ -function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} +function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N,PM}) where {T,L,N,PM} # TODO - remove dataset from args. - P = PopMember{T,L,N} # Dominating pareto curve - must be better than all simpler equations - dominating = P[] + dominating = PM[] for size in eachindex(hallOfFame.members) if !hallOfFame.exists[size] continue @@ -123,57 +147,415 @@ function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N} return dominating end -let header_parts = ( - rpad(styled"{bold:{underline:Complexity}}", 10), - rpad(styled"{bold:{underline:Loss}}", 9), - rpad(styled"{bold:{underline:Score}}", 9), - styled"{bold:{underline:Equation}}", +""" + member_to_row(member::AbstractPopMember, dataset::Dataset, options::AbstractOptions; + pretty::Bool=true) + +Convert a PopMember to a row representation for display/export. + +This is the primary extension point for custom PopMember types. Users can override this +method to include additional fields in the output. + +# Arguments +- `member`: The population member to convert +- `dataset`: Dataset for formatting equation strings +- `options`: Options controlling complexity and equation formatting +- `pretty`: Whether to use pretty-printing for equations (default: true) + +# Returns +A NamedTuple containing the member's data. 
Default fields are: +- `complexity`: Expression complexity +- `loss`: Raw loss value +- `cost`: Cost including complexity penalty +- `birth`: Birth order/generation +- `ref`: Unique reference ID +- `parent`: Parent reference ID +- `equation`: Formatted equation string + +# Example 1: Adding custom fields to a custom PopMember +```julia +function SymbolicRegression.HallOfFameModule.member_to_row( + member::MyCustomPopMember, + dataset::Dataset, + options::AbstractOptions; + kwargs... +) + base = invoke(member_to_row, Tuple{AbstractPopMember, Dataset, AbstractOptions}, + member, dataset, options; kwargs...) + return merge(base, (my_field = member.custom_data,)) +end +``` + +# Example 2: Displaying custom fields in the Hall of Fame +After extending `member_to_row`, create custom columns to display your fields: +```julia +using Printf + +custom_columns = [ + HOFColumn(:complexity, "C", row -> row.complexity, string, 5, :right), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.3e", x), 9, :right), + HOFColumn(:my_field, "MyField", row -> row.my_field, x -> @sprintf("%.2f", x), 10, :right), + HOFColumn(:equation, "Equation", row -> row.equation, identity, nothing, :left) +] + +# Display with custom columns +str = string_dominating_pareto_curve(hof, dataset, options; columns=custom_columns) +println(str) + +# Or export via Tables.jl with custom columns +rows = hof_rows(hof, dataset, options; columns=custom_columns) +using DataFrames +df = DataFrame(rows) +``` +""" +function member_to_row( + member::AbstractPopMember, dataset::Dataset, options::AbstractOptions; pretty::Bool=true +) + eqn_string = string_tree( + member.tree, + options; + display_variable_names=dataset.display_variable_names, + X_sym_units=dataset.X_sym_units, + y_sym_units=dataset.y_sym_units, + pretty=pretty, + ) + prefix = make_prefix(member.tree, options, dataset) + eqn_string = prefix * eqn_string + return ( + complexity=compute_complexity(member, options), + loss=member.loss, + cost=member.cost, + birth=member.birth, + ref=member.ref, + parent=member.parent, + equation=eqn_string, ) - @eval const HEADER = join($(header_parts), " ") - @eval const HEADER_WITHOUT_SCORE = join($(header_parts[[1, 2, 4]]), " ") end -show_score_column(options::AbstractOptions) = options.loss_scale == :log +""" + HOFColumn + +Specification for a column in Hall of Fame display and export. + +# Fields +- `name::Symbol`: Column identifier (key in the row NamedTuple) +- `header::String`: Display header text +- `getter::Function`: Function `(row::NamedTuple) -> value` to extract/compute column value +- `formatter::Function`: Function `(value) -> String` for display formatting (display only) +- `width::Union{Int,Nothing}`: Display width (nothing for auto-sizing) +- `alignment::Symbol`: Text alignment - `:left`, `:right`, or `:center` + +# Example +```julia +# Simple column that extracts an existing field +complexity_col = HOFColumn( + :complexity, "Complexity", + row -> row.complexity, + x -> string(x), + 10, :right +) + +# Computed column +r2_col = HOFColumn( + :r2, "R²", + row -> compute_r2(row), # Custom computation + x -> @sprintf("%.3f", x), + 8, :right +) +``` +""" +struct HOFColumn + name::Symbol + header::String + getter::Function + formatter::Function + width::Union{Int,Nothing} + alignment::Symbol +end +""" + default_columns(options::AbstractOptions) -> Vector{HOFColumn} + +Return the default column specifications for Hall of Fame display. 
+ +The default columns are: +- Complexity (right-aligned, width 10) +- Loss (right-aligned, width 9, scientific notation) +- Score (conditional on `options.loss_scale == :log`, right-aligned, width 9) +- Equation (left-aligned, auto-width) + +Users can customize by modifying this vector or creating their own. +""" +function default_columns(options::AbstractOptions) + cols = HOFColumn[ + HOFColumn( + :complexity, + "Complexity", + row -> row.complexity, + x -> @sprintf("%d", x), + 10, + :right, + ), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.3e", x), 9, :right), + ] + + # Add score column for logarithmic loss scale + if options.loss_scale == :log + push!( + cols, + HOFColumn( + :score, "Score", row -> row.score, x -> @sprintf("%.3e", x), 9, :right + ), + ) + end + + # Equation column (special handling in display due to wrapping) + push!( + cols, + HOFColumn(:equation, "Equation", row -> row.equation, identity, nothing, :left), + ) + + return cols +end + +""" + HOFRows + +A lazy iterator for HallOfFame members that computes rows on-demand. +This struct implements the Tables.jl interface for easy export to DataFrames, CSV, etc. + +# Fields +- `members`: Vector of PopMembers to iterate over +- `dataset`: Dataset for formatting equations +- `options`: Options for complexity and formatting +- `include_score`: Whether to compute and include Pareto improvement scores +- `pretty`: Whether to use pretty-printing for equations +- `columns`: Optional column specifications (nothing = all columns from member_to_row) +""" +struct HOFRows{PM<:AbstractPopMember} + members::Vector{PM} + dataset::Dataset + options::AbstractOptions + include_score::Bool + pretty::Bool + columns::Union{Vector{HOFColumn},Nothing} +end + +# Helper function to create a single row with optional score and column filtering +@unstable function _make_row(view::HOFRows, i::Int, scores) + # Get full row from member_to_row + row = member_to_row(view.members[i], view.dataset, view.options; pretty=view.pretty) + + # Add score if computed + row = scores === nothing ? row : (; row..., score=scores[i]) + + # Apply column filtering if specified + if view.columns !== nothing + # Build filtered row using column getters + filtered_values = [col.getter(row) for col in view.columns] + filtered_names = Tuple(col.name for col in view.columns) + return NamedTuple{filtered_names}(filtered_values) + end + + return row +end + +# Make HOFRows iterable +Base.length(view::HOFRows) = length(view.members) +Base.eltype(::Type{<:HOFRows}) = NamedTuple + +function Base.iterate(view::HOFRows) + isempty(view.members) && return nothing + + # Compute all scores upfront if needed + scores = view.include_score ? compute_scores(view.members, view.options) : nothing + state = (scores, 1) + + row = _make_row(view, 1, scores) + return (row, state) +end + +function Base.iterate(view::HOFRows, state) + scores, i = state + i += 1 + i > length(view.members) && return nothing + + row = _make_row(view, i, scores) + return (row, (scores, i)) +end + +""" + hof_rows(hof::HallOfFame, dataset::Dataset, options::AbstractOptions; + pareto_only::Bool=true, include_score::Bool=pareto_only, + pretty::Bool=true, columns::Union{Vector{HOFColumn},Nothing}=nothing) + +This function returns an `HOFRows` object. 
+ +# Arguments +- `hof`: The HallOfFame to export +- `dataset`: Dataset for formatting equations +- `options`: Options controlling complexity and formatting +- `pareto_only`: Only include Pareto frontier members (default: true) +- `include_score`: Include Pareto improvement scores (default: same as `pareto_only`) +- `pretty`: Use pretty-printing for equations (default: true) +- `columns`: Optional column specifications (default: nothing = all columns from member_to_row) + +# Returns +An `HOFRows` object that can be used with Tables.jl-compatible consumers like +`DataFrame`, `CSV.write`, etc. + +# Examples +```julia +# Get a Tables.jl view of the Pareto frontier +rows = hof_rows(hof, dataset, options) + +# Convert to DataFrame (requires DataFrames.jl) +using DataFrames +df = DataFrame(rows) + +# Get all members without scores +all_rows = hof_rows(hof, dataset, options; pareto_only=false, include_score=false) + +# Get only specific columns +custom_cols = [ + HOFColumn(:complexity, "Complexity", row -> row.complexity, string, 10, :right), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.3e", x), 9, :right) +] +filtered_rows = hof_rows(hof, dataset, options; columns=custom_cols) +``` +""" +function hof_rows( + hof::HallOfFame, + dataset::Dataset, + options::AbstractOptions; + pareto_only::Bool=true, + include_score::Bool=pareto_only, + pretty::Bool=true, + columns::Union{Vector{HOFColumn},Nothing}=nothing, +) + members = if pareto_only + calculate_pareto_frontier(hof) + else + [m for (m, ex) in zip(hof.members, hof.exists) if ex] + end + + return HOFRows(members, dataset, options, include_score, pretty, columns) +end + +""" + string_dominating_pareto_curve( + hallOfFame, dataset, options; + width::Union{Integer,Nothing}=nothing, + pretty::Bool=true, + columns::Union{Vector{HOFColumn},Nothing}=nothing + ) + +Format the Pareto frontier as a pretty-printed string table. + +# Arguments +- `hallOfFame`: The HallOfFame to display +- `dataset`: Dataset for formatting equations +- `options`: Options controlling complexity and formatting +- `width`: Terminal width (default: 100) +- `pretty`: Use pretty-printing for equations (default: true) +- `columns`: Column specifications (default: nothing = use default_columns(options)) + +# Example with custom columns +```julia +custom_cols = [ + HOFColumn(:complexity, "C", row -> row.complexity, string, 5, :right), + HOFColumn(:loss, "Loss", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right), + HOFColumn(:equation, "Equation", row -> row.equation, identity, nothing, :left) +] +str = string_dominating_pareto_curve(hof, dataset, options; columns=custom_cols) +``` +""" function string_dominating_pareto_curve( - hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing, pretty::Bool=true + hallOfFame, + dataset, + options; + width::Union{Integer,Nothing}=nothing, + pretty::Bool=true, + columns::Union{Vector{HOFColumn},Nothing}=nothing, ) + # Use default columns if not specified + cols = columns === nothing ? default_columns(options) : columns + terminal_width = (width === nothing) ? 
100 : max(100, width::Integer) buffer = AnnotatedIOBuffer(IOBuffer()) + + # Print top border println(buffer, '─'^(terminal_width - 1)) - if show_score_column(options) - println(buffer, HEADER) - else - println(buffer, HEADER_WITHOUT_SCORE) - end - formatted = format_hall_of_fame(hallOfFame, options) - for (tree, score, loss, complexity) in - zip(formatted.trees, formatted.scores, formatted.losses, formatted.complexities) - eqn_string = string_tree( - tree, - options; - display_variable_names=dataset.display_variable_names, - X_sym_units=dataset.X_sym_units, - y_sym_units=dataset.y_sym_units, - pretty, - ) - prefix = make_prefix(tree, options, dataset) - eqn_string = prefix * eqn_string - stats_columns_string = if show_score_column(options) - @sprintf("%-10d %-8.3e %-8.3e ", complexity, loss, score) + # Build header from column specs + header_parts = map(cols) do col + header_text = styled"{bold:{underline:$(col.header)}}" + if col.width === nothing + # Last column (typically equation) - no padding + header_text else - @sprintf("%-10d %-8.3e ", complexity, loss) + # Fixed-width column - pad to width + rpad(header_text, col.width) end - left_cols_width = length(stats_columns_string) - print(buffer, stats_columns_string) - print( - buffer, - wrap_equation_string( - eqn_string, left_cols_width + length(prefix), terminal_width - ), - ) end + println(buffer, join(header_parts, " ")) + + # Get rows (without column filtering, we'll format ourselves) + rows_view = hof_rows( + hallOfFame, dataset, options; pareto_only=true, include_score=true, pretty=pretty + ) + members = rows_view.members + + # Format each row + for (i, full_row) in enumerate(rows_view) + member = members[i] + + # Format all columns except the last one (which may need wrapping) + formatted_cols = String[] + for (col_idx, col) in enumerate(cols) + value = col.getter(full_row) + formatted = col.formatter(value) + + if col_idx == length(cols) + # Last column - handle separately for wrapping + # Calculate left margin from previous columns + left_cols_width = sum( + length(formatted_cols[j]) + 2 for j in 1:(length(formatted_cols)) + ) + + # Handle equation prefix if it's an equation column + if col.name == :equation && haskey(full_row, :equation) + prefix = make_prefix(member.tree, options, dataset) + wrapped = wrap_equation_string( + formatted, left_cols_width + length(prefix), terminal_width + ) + print(buffer, join(formatted_cols, " ")) + print(buffer, " ") + print(buffer, wrapped) + else + # Non-equation last column - just print + push!(formatted_cols, formatted) + println(buffer, join(formatted_cols, " ")) + end + else + # Non-last column - format with alignment and width + if col.width !== nothing + if col.alignment == :right + formatted = lpad(formatted, col.width) + elseif col.alignment == :center + formatted = lpad( + rpad(formatted, (col.width + length(formatted)) ÷ 2), col.width + ) + else # :left + formatted = rpad(formatted, col.width) + end + end + push!(formatted_cols, formatted) + end + end + end + + # Print bottom border print(buffer, '─'^(terminal_width - 1)) return dump_buffer(buffer) end @@ -214,31 +596,38 @@ function wrap_equation_string(eqn_string, left_cols_width, terminal_width) return dump_buffer(buffer) end -function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L} - dominating = calculate_pareto_frontier(hof) +""" + compute_scores(members::Vector{<:AbstractPopMember}, options::AbstractOptions) - # Only check for negative losses if using logarithmic scaling - options.loss_scale == :log && for 
member in dominating - if member.loss < 0.0 - throw( - DomainError( - member.loss, - "Your loss function must be non-negative. To allow negative losses, set the `loss_scale` to linear, or consider wrapping your loss inside an exponential.", - ), - ) - end - end +Compute improvement scores for an ordered sequence of members. - trees = [member.tree for member in dominating] - losses = [member.loss for member in dominating] - complexities = [compute_complexity(member, options) for member in dominating] - scores = Array{L}(undef, length(dominating)) +Scores measure the improvement in loss per unit complexity compared to the previous +member in the sequence. The first member always has a score of zero. + +This function works with any ordered sequence of members (e.g., Pareto frontier, +complexity-sorted members, etc.). + +# Arguments +- `members`: Vector of PopMembers in the desired order +- `options`: Options controlling the loss scale (`:linear` or `:log`) - cur_loss = typemax(L) - last_loss = cur_loss +# Returns +Vector of scores with the same length as `members` +""" +function compute_scores( + members::Vector{<:AbstractPopMember{T,L,N}}, options::AbstractOptions +) where {T,L,N} + isempty(members) && return L[] + + scores = Vector{L}(undef, length(members)) + + complexities = [compute_complexity(member, options) for member in members] + losses = [member.loss for member in members] + + last_loss = typemax(L) last_complexity = zero(eltype(complexities)) - for i in 1:length(dominating) + for i in eachindex(members) complexity = complexities[i] cur_loss = losses[i] delta_c = complexity - last_complexity @@ -254,6 +643,30 @@ function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L} last_loss = cur_loss last_complexity = complexity end + + return scores +end + +function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L} + dominating = calculate_pareto_frontier(hof) + + # Only check for negative losses if using logarithmic scaling + options.loss_scale == :log && for member in dominating + if member.loss < 0.0 + throw( + DomainError( + member.loss, + "Your loss function must be non-negative. 
To allow negative losses, set the `loss_scale` to linear, or consider wrapping your loss inside an exponential.", + ), + ) + end + end + + trees = [member.tree for member in dominating] + losses = [member.loss for member in dominating] + complexities = [compute_complexity(member, options) for member in dominating] + scores = compute_scores(dominating, options) + return (; trees, scores, losses, complexities) end function compute_direct_score(cur_loss, last_loss, delta_c) @@ -276,4 +689,7 @@ function format_hall_of_fame(hof::AbstractVector{<:HallOfFame}, options) end # TODO: Re-use this in `string_dominating_pareto_curve` +# Type accessor for HallOfFame +popmember_type(::Type{<:HallOfFame{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index bb2fddfb0..ccc6c33ee 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -40,6 +40,7 @@ using ..CoreModule: get_expression_type, check_warm_start_compatibility using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using ..PopMemberModule: default_popmember_type using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore diff --git a/src/Migration.jl b/src/Migration.jl index f7fe61b89..3988b81ea 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -2,7 +2,7 @@ module MigrationModule using ..CoreModule: AbstractOptions using ..PopulationModule: Population -using ..PopMemberModule: PopMember, reset_birth! +using ..PopMemberModule: AbstractPopMember, PopMember, reset_birth! using ..UtilsModule: poisson_sample """ @@ -14,7 +14,7 @@ to do so. The original migrant population is not modified. Pass with, e.g., """ function migrate!( migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat -) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} +) where {T,L,N,PM<:AbstractPopMember{T,L,N},P<:Population{T,L,N,PM}} base_pop = migration.second population_size = length(base_pop.members) mean_number_replaced = population_size * frac diff --git a/src/Mutate.jl b/src/Mutate.jl index f5fd88457..412d7ea06 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -22,7 +22,7 @@ using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember, create_child using ..MutationFunctionsModule: mutate_constant, mutate_operator, @@ -40,10 +40,10 @@ using ..MutationFunctionsModule: using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder -abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end +abstract type AbstractMutationResult{N<:AbstractExpression,P<:AbstractPopMember} end """ - MutationResult{N<:AbstractExpression,P<:PopMember} + MutationResult{N<:AbstractExpression,P<:AbstractPopMember} Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions. @@ -61,7 +61,8 @@ This struct encapsulates the result of a mutation operation. Either a new expres Return the `member` if you want to return immediately, and have computed the loss value as part of the mutation. 
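# Example

A hedged sketch of the two return styles (not taken from any specific mutation; `my_rewrite`
and `new_member` are placeholders):

```julia
# Propose a new tree only; the caller evaluates its cost and loss:
MutationResult{N,P}(; tree=my_rewrite(tree))

# Return an already-evaluated member and skip the usual evaluation step:
MutationResult{N,P}(; member=new_member, num_evals=1.0, return_immediately=true)
```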
""" -struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P} +struct MutationResult{N<:AbstractExpression,P<:AbstractPopMember} <: + AbstractMutationResult{N,P} tree::Union{N,Nothing} member::Union{P,Nothing} num_evals::Float64 @@ -73,7 +74,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes member::Union{_P,Nothing}=nothing, num_evals::Float64=0.0, return_immediately::Bool=false, - ) where {_N<:AbstractExpression,_P<:PopMember} + ) where {_N<:AbstractExpression,_P<:AbstractPopMember} @assert( (tree === nothing) ⊻ (member === nothing), "Mutation result must return either a tree or a pop member, not both" @@ -83,7 +84,7 @@ struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationRes end """ - condition_mutation_weights!(weights::AbstractMutationWeights, member::PopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) + condition_mutation_weights!(weights::AbstractMutationWeights, member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, nfeatures::Int) Adjusts the mutation weights based on the properties of the current member and options. @@ -93,7 +94,7 @@ Note that the weights were already copied, so you don't need to worry about muta # Arguments - `weights::AbstractMutationWeights`: The mutation weights to be adjusted. -- `member::PopMember`: The current population member being mutated. +- `member::AbstractPopMember`: The current population member being mutated. - `options::AbstractOptions`: The options that guide the mutation process. - `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. - `nfeatures::Int`: The number of features available in the dataset. @@ -104,7 +105,7 @@ function condition_mutation_weights!( options::AbstractOptions, curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:AbstractExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:AbstractExpression,P<:AbstractPopMember{T,L,N}} tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 @@ -159,7 +160,7 @@ Use this to modify how `mutate_constant` changes for an expression type. 
function condition_mutate_constant!( ::Type{<:AbstractExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) @@ -181,7 +182,7 @@ end tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:AbstractPopMember{T,L,N}} parent_ref = member.ref num_evals = 0.0 @@ -253,14 +254,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -277,14 +277,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -321,14 +320,13 @@ end end mutation_accepted = false return ( - PopMember( + create_child( + member, copy_into!(node_storage, member.tree), before_cost, before_loss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, + options; + parent_ref=parent_ref, ), mutation_accepted, num_evals, @@ -339,25 +337,22 @@ end tmp_recorder["reason"] = "pass" end mutation_accepted = true - return ( - PopMember( - tree, - after_cost, - after_loss, - options, - newSize; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, + new_member = create_child( + member, + tree, + after_cost, + after_loss, + options; + complexity=newSize, + parent_ref=parent_ref, ) + return (new_member, mutation_accepted, num_evals) end end @generated function _dispatch_mutations!( tree::AbstractExpression, - member::PopMember, + member::AbstractPopMember, mutation_choice::Symbol, weights::W, options::AbstractOptions; @@ -386,7 +381,7 @@ end mutation_weights::AbstractMutationWeights, options::AbstractOptions; kws..., - ) where {N<:AbstractExpression,P<:PopMember,S} + ) where {N<:AbstractExpression,P<:AbstractPopMember,S} Perform a mutation on the given `tree` and `member` using the specified mutation type `S`. Various `kws` are provided to access other data needed for some mutations. @@ -414,7 +409,7 @@ so it can always return immediately. """ function mutate!( ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws... 
-) where {N<:AbstractExpression,P<:PopMember,S} +) where {N<:AbstractExpression,P<:AbstractPopMember,S} return error("Unknown mutation choice: $S") end @@ -427,7 +422,7 @@ function mutate!( recorder::RecordType, temperature, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_constant(tree, temperature, options) @recorder recorder["type"] = "mutate_constant" return MutationResult{N,P}(; tree=tree) @@ -441,7 +436,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_operator(tree, options) @recorder recorder["type"] = "mutate_operator" return MutationResult{N,P}(; tree=tree) @@ -456,7 +451,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = mutate_feature(tree, nfeatures) @recorder recorder["type"] = "mutate_feature" return MutationResult{N,P}(; tree=tree) @@ -470,7 +465,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = swap_operands(tree) @recorder recorder["type"] = "swap_operands" return MutationResult{N,P}(; tree=tree) @@ -485,7 +480,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} if rand() < 0.5 tree = append_random_op(tree, options, nfeatures) @recorder recorder["type"] = "add_node:append" @@ -505,7 +500,7 @@ function mutate!( recorder::RecordType, nfeatures, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = insert_random_op(tree, options, nfeatures) @recorder recorder["type"] = "insert_node" return MutationResult{N,P}(; tree=tree) @@ -519,7 +514,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = delete_random_op!(tree) @recorder recorder["type"] = "delete_node" return MutationResult{N,P}(; tree=tree) @@ -533,7 +528,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = form_random_connection!(tree) @recorder recorder["type"] = "form_connection" return MutationResult{N,P}(; tree=tree) @@ -547,7 +542,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = break_random_connection!(tree) @recorder recorder["type"] = "break_connection" return MutationResult{N,P}(; tree=tree) @@ -561,7 +556,7 @@ function mutate!( options::AbstractOptions; recorder::RecordType, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} tree = randomly_rotate_tree!(tree) @recorder recorder["type"] = "rotate_tree" return MutationResult{N,P}(; tree=tree) @@ -577,22 +572,15 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @assert options.should_simplify simplify_tree!(tree, options.operators) tree = combine_operators(tree, 
options.operators) @recorder recorder["type"] = "simplify" - return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options; - parent=parent_ref, - deterministic=options.deterministic, - ), - return_immediately=true, + new_member = create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ) + return MutationResult{N,P}(; member=new_member, return_immediately=true) end function mutate!( @@ -605,7 +593,7 @@ function mutate!( curmaxsize, nfeatures, kws..., -) where {T,N<:AbstractExpression{T},P<:PopMember} +) where {T,N<:AbstractExpression{T},P<:AbstractPopMember} tree = randomize_tree(tree, curmaxsize, options, nfeatures) @recorder recorder["type"] = "randomize" return MutationResult{N,P}(; tree=tree) @@ -620,7 +608,7 @@ function mutate!( recorder::RecordType, dataset::Dataset, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} cur_member, new_num_evals = optimize_constants(dataset, member, options) @recorder recorder["type"] = "optimize" return MutationResult{N,P}(; @@ -637,21 +625,15 @@ function mutate!( recorder::RecordType, parent_ref, kws..., -) where {N<:AbstractExpression,P<:PopMember} +) where {N<:AbstractExpression,P<:AbstractPopMember} @recorder begin recorder["type"] = "identity" recorder["result"] = "accept" recorder["reason"] = "identity" end return MutationResult{N,P}(; - member=PopMember( - tree, - member.cost, - member.loss, - options, - compute_complexity(tree, options); - parent=parent_ref, - deterministic=options.deterministic, + member=create_child( + member, tree, member.cost, member.loss, options; parent_ref=parent_ref ), return_immediately=true, ) @@ -665,7 +647,7 @@ function crossover_generation( curmaxsize::Int, options::AbstractOptions; recorder::RecordType=RecordType(), -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:AbstractPopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree crossover_accepted = false @@ -704,23 +686,23 @@ function crossover_generation( ) num_evals += 2 * dataset_fraction(dataset) - baby1 = PopMember( + baby1 = create_child( + (member1, member2), child_tree1::AbstractExpression, after_cost1, after_loss1, - options, - afterSize1; - parent=member1.ref, - deterministic=options.deterministic, + options; + complexity=afterSize1, + parent_ref=member1.ref, )::P - baby2 = PopMember( + baby2 = create_child( + (member1, member2), child_tree2::AbstractExpression, after_cost2, after_loss2, - options, - afterSize2; - parent=member2.ref, - deterministic=options.deterministic, + options; + complexity=afterSize2, + parent_ref=member2.ref, )::P @recorder begin diff --git a/src/Options.jl b/src/Options.jl index cb8d3bbf3..6ee7c4222 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -225,6 +225,8 @@ recommend_loss_function_expression(expression_type) = false create_mutation_weights(w::AbstractMutationWeights) = w create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) 
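# Illustrative usage: a custom `AbstractPopMember` subtype can be selected when
# constructing options via the `popmember_type` keyword, e.g.
#     options = Options(; popmember_type=MyMember)
# where `MyMember` is a hypothetical user-defined subtype, analogous to
# `CustomPopMember` in test/test_abstract_popmember.jl.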
+function default_popmember_type end + @unstable function with_max_degree_from_context( node_type, user_provided_operators, operators ) @@ -652,6 +654,7 @@ $(OPTION_DESCRIPTIONS) terminal_width::Union{Nothing,Integer}=nothing, use_recorder::Bool=false, recorder_file::AbstractString="pysr_recorder.json", + popmember_type::Type=default_popmember_type(), ### Not search options; just construction options: define_helper_functions::Bool=true, ######################################### @@ -1031,6 +1034,7 @@ $(OPTION_DESCRIPTIONS) expression_type, typeof(expression_options), typeof(set_mutation_weights), + popmember_type, turbo, bumper, deprecated_return_state::Union{Bool,Nothing}, @@ -1104,6 +1108,7 @@ $(OPTION_DESCRIPTIONS) deterministic, define_helper_functions, use_recorder, + popmember_type, ) return options diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 2c3046204..495035760 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -183,6 +183,7 @@ struct Options{ E<:AbstractExpression, EO<:NamedTuple, MW<:AbstractMutationWeights, + PM, _turbo, _bumper, _return_state, @@ -256,6 +257,7 @@ struct Options{ deterministic::Bool define_helper_functions::Bool use_recorder::Bool + popmember_type::Type{PM} end function Base.print(io::IO, @nospecialize(options::Options)) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 2717afbdc..b7c9ab2f6 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -24,7 +24,7 @@ using ..CoreModule: AbstractExpressionSpec, get_indices, ExpressionSpecModule as ES -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB @@ -102,7 +102,7 @@ end function MM.condition_mutate_constant!( ::Type{<:ParametricExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/src/PopMember.jl b/src/PopMember.jl index bd195a6c2..71f8707de 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,13 +2,32 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree +import DynamicExpressions: constructorof, with_type_parameters using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression +import ..CoreModule.OptionsModule: default_popmember_type import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_cost +""" + AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} + +Abstract type for population members. Defines the interface that all population members must implement. + +# Required fields (accessed via getproperty/setproperty!) +- `tree::N`: The expression tree +- `cost::L`: The cost including complexity penalty and normalization +- `loss::L`: The raw loss value +- `birth::Int`: Birth order/generation number +- `ref::Int`: Unique reference ID +- `parent::Int`: Parent reference ID +- `complexity::Int`: Cached complexity (accessed via getfield/setfield! 
for special handling) +""" +abstract type AbstractPopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} end + # Define a member of population by equation, cost, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} <: + AbstractPopMember{T,L,N} tree::N cost::L # Inludes complexity penalty, normalization loss::L # Raw loss @@ -19,7 +38,9 @@ mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} ref::Int parent::Int end -@inline function Base.setproperty!(member::PopMember, field::Symbol, value) + +# Generic interface implementations for AbstractPopMember +@inline function Base.setproperty!(member::AbstractPopMember, field::Symbol, value) if field == :complexity throw( error("Don't set `.complexity` directly. Use `recompute_complexity!` instead.") @@ -34,7 +55,7 @@ end end return setfield!(member, field, value) end -@unstable @inline function Base.getproperty(member::PopMember, field::Symbol) +@unstable @inline function Base.getproperty(member::AbstractPopMember, field::Symbol) if field == :complexity throw( error("Don't access `.complexity` directly. Use `compute_complexity` instead.") @@ -145,7 +166,7 @@ function PopMember( ) end -function Base.copy(p::P) where {P<:PopMember} +function Base.copy(p::PopMember) tree = copy(p.tree) cost = copy(p.cost) loss = copy(p.loss) @@ -153,17 +174,17 @@ function Base.copy(p::P) where {P<:PopMember} complexity = copy(getfield(p, :complexity)) ref = copy(p.ref) parent = copy(p.parent) - return P(tree, cost, loss, birth, complexity, ref, parent) + return PopMember(tree, cost, loss, birth, complexity, ref, parent) end -function reset_birth!(p::PopMember; deterministic::Bool) +function reset_birth!(p::AbstractPopMember; deterministic::Bool) p.birth = get_birth_order(; deterministic) return p end # Can read off complexity directly from pop members function compute_complexity( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = getfield(member, :complexity) complexity == -1 && return recompute_complexity!(member, options; break_sharing) @@ -171,11 +192,89 @@ function compute_complexity( return complexity end function recompute_complexity!( - member::PopMember, options::AbstractOptions; break_sharing=Val(false) + member::AbstractPopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = compute_complexity(member.tree, options; break_sharing) setfield!(member, :complexity, complexity) return complexity end +""" + create_child(parent::P, tree::AbstractExpression{T}, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref) where {T,L,P<:PopMember{T,L}} + +Create a new PopMember with a potentially different expression type. +Used by embed_metadata where the expression gains metadata. 
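# Example

A sketch of how a custom `AbstractPopMember` subtype could specialize this to carry an
extra field to the child (`MyMember` and its keyword constructor, including the
`generation` field, are hypothetical):

```julia
import SymbolicRegression.PopMemberModule: create_child

function create_child(
    parent::MyMember, tree, cost, loss, options; complexity=nothing, parent_ref
)
    c = @something complexity SymbolicRegression.compute_complexity(tree, options)
    return MyMember(
        tree, cost, loss, options, c;
        parent=parent_ref,
        deterministic=options.deterministic,
        generation=parent.generation + 1,  # hypothetical extra field
    )
end
```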
+""" +function create_child( + parent::P, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, +) where {T,L,P<:PopMember{T,L}} + actual_complexity = @something complexity compute_complexity(tree, options) + return PopMember( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + +""" + create_child(parents::Tuple{P,P}, tree, cost, loss, options; + complexity::Union{Int,Nothing}=nothing, parent_ref) where P<:AbstractPopMember + +Create a new PopMember from two parents (crossover case). +Custom types should override to blend their additional fields. +""" +function create_child( + parents::Tuple{P,P}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, +) where {T,L,P<:PopMember{T,L}} + actual_complexity = @something complexity compute_complexity(tree, options) + return PopMember( + tree, + cost, + loss, + options, + actual_complexity; + parent=parent_ref, + deterministic=options.deterministic, + ) +end + +# Function to extract PopMember type from Population or HallOfFame types +function popmember_type end + +@unstable default_popmember_type() = PopMember +@unstable constructorof(::Type{<:PopMember}) = PopMember + +@inline function with_expression_type( + ::Type{<:PopMember{T,L}}, ::Type{N} +) where {T,L,N<:AbstractExpression{T}} + return PopMember{T,L,N} +end + +@inline function with_type_parameters( + ::Type{<:PopMember}, ::Type{T}, ::Type{L}, ::Type{N} +) where {T,L,N} + return PopMember{T,L,N} +end + +@inline function expression_type(::Type{<:AbstractPopMember{<:Any,<:Any,N}}) where {N} + return N +end + end diff --git a/src/Population.jl b/src/Population.jl index 739ca828e..00f603258 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -2,26 +2,29 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpression, string_tree +using DynamicExpressions: AbstractExpression, string_tree, constructorof using ..CoreModule: AbstractOptions, Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: eval_cost, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree -using ..PopMemberModule: PopMember +using ..PopMemberModule: AbstractPopMember, PopMember +import ..PopMemberModule: popmember_type using ..UtilsModule: bottomk_fast, argmin_fast, PerTaskCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} +struct Population{ + T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N} +} + members::Array{PM,1} n::Int end """ - Population(pop::Array{PopMember{T,L}, 1}) + Population(pop::Array{<:AbstractPopMember, 1}) Create population from list of PopMembers. 
""" -function Population(pop::Vector{<:PopMember}) +function Population(pop::Vector{<:AbstractPopMember}) return Population(pop, size(pop, 1)) end @@ -41,23 +44,34 @@ function Population( npop=nothing, ) where {T,L} @assert (population_size !== nothing) ⊻ (npop !== nothing) - population_size = if npop === nothing - population_size - else - npop - end - return Population( - [ - PopMember( + population_size = something(population_size, npop) + PM = options.popmember_type + + # Create first member to get concrete type + first_member = constructorof(PM)( + dataset, + gen_random_tree(nlength, options, nfeatures, T), + options; + parent=-1, + deterministic=options.deterministic, + ) + + # Use the concrete type for the array + members = typeof(first_member)[ + if i == 1 + first_member + else + constructorof(PM)( dataset, gen_random_tree(nlength, options, nfeatures, T), options; parent=-1, deterministic=options.deterministic, - ) for _ in 1:population_size - ], - population_size, - ) + ) + end for i in 1:population_size + ] + + return Population(members, population_size) end """ Population(X::AbstractMatrix{T}, y::AbstractVector{T}; @@ -90,8 +104,8 @@ Create random population and score them on the dataset. ) end -function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} - copied_members = Vector{PopMember{T,L,N}}(undef, pop.n) +function Base.copy(pop::P)::P where {T,L,N,PM,P<:Population{T,L,N,PM}} + copied_members = Vector{PM}(undef, pop.n) Threads.@threads for i in 1:(pop.n) copied_members[i] = copy(pop.members[i]) end @@ -118,7 +132,7 @@ function _best_of_sample( members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N}} p = options.tournament_selection_p n = length(members) # == tournament_selection_n adjusted_costs = Vector{L}(undef, n) @@ -157,7 +171,7 @@ function _best_of_sample( end return members[chosen_idx] end -_get_cost(member::PopMember) = member.cost +_get_cost(member::AbstractPopMember) = member.cost const CACHED_WEIGHTS = let init_k = collect(0:5), @@ -218,4 +232,7 @@ function record_population(pop::Population, options::AbstractOptions)::RecordTyp ) end +# Type accessor for Population +popmember_type(::Type{<:Population{T,L,N,PM}}) where {T,L,N,PM} = PM + end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 58d492e8a..05efe4c09 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -12,12 +12,18 @@ using DispatchDoctor: @unstable using Logging: AbstractLogger using DynamicExpressions: - AbstractExpression, string_tree, parse_expression, EvalOptions, with_type_parameters + AbstractExpression, + string_tree, + parse_expression, + EvalOptions, + with_type_parameters, + constructorof using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features +using ..CoreModule: + Dataset, AbstractOptions, Options, RecordType, max_features, create_expression using ..ComplexityModule: compute_complexity using ..PopulationModule: Population -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ConstantOptimizationModule: optimize_constants using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen @@ -28,6 +34,15 @@ using ..CheckConstraintsModule: check_constraints function logging_callback! 
end +@unstable @inline function infer_popmember_type( + ::Type{T}, ::Type{L}, ::Type{D}, options +) where {T,L,D<:Dataset} + NodeType = with_type_parameters(options.node_type, T) + N = Base.promote_op(create_expression, NodeType, typeof(options), D) + N in (Any, Union{}) && error("Failed to infer expression type") + return with_type_parameters(options.popmember_type, T, L, N) +end + """ @filtered_async expr @@ -581,8 +596,9 @@ The state of the search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. """ -Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} <: - AbstractSearchState{T,L,N} +Base.@kwdef struct SearchState{ + T,L,N<:AbstractExpression{T},PM<:AbstractPopMember{T,L,N},WorkerOutputType,ChannelType +} <: AbstractSearchState{T,L,N} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -590,16 +606,16 @@ Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,Cha channels::Vector{Vector{ChannelType}} worker_assignment::WorkerAssignments task_order::Vector{Tuple{Int,Int}} - halls_of_fame::Vector{HallOfFame{T,L,N}} - last_pops::Vector{Vector{Population{T,L,N}}} - best_sub_pops::Vector{Vector{Population{T,L,N}}} + halls_of_fame::Vector{HallOfFame{T,L,N,PM}} + last_pops::Vector{Vector{Population{T,L,N,PM}}} + best_sub_pops::Vector{Vector{Population{T,L,N,PM}}} all_running_search_statistics::Vector{RunningSearchStatistics} num_evals::Vector{Vector{Float64}} cycles_remaining::Vector{Int} cur_maxsizes::Vector{Int} stdin_reader::StdinReader record::Base.RefValue{RecordType} - seed_members::Vector{Vector{PopMember{T,L,N}}} + seed_members::Vector{Vector{PM}} end function save_to_file( @@ -716,7 +732,7 @@ end function update_hall_of_fame!( hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions -) where {PM<:PopMember} +) where {PM<:AbstractPopMember} for member in members size = compute_complexity(member, options) valid_size = 0 < size <= options.maxsize @@ -786,12 +802,22 @@ end """Parse user-provided guess expressions and convert them into optimized `PopMember` objects for each output dataset.""" -function parse_guesses( +@unstable function parse_guesses( + ::Type{P}, + guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, + datasets::Vector{D}, + options::AbstractOptions, +) where {T,L,P<:AbstractPopMember{T,L},D<:Dataset{T,L}} + ConcreteP = infer_popmember_type(T, L, D, options) + return _parse_guesses_impl(ConcreteP, guesses, datasets, options) +end + +@inline function _parse_guesses_impl( ::Type{P}, guesses::Union{AbstractVector,AbstractVector{<:AbstractVector}}, datasets::Vector{D}, options::AbstractOptions, -) where {T,L,P<:PopMember{T,L},D<:Dataset{T,L}} +) where {T,L,N,P<:AbstractPopMember{T,L,N},D<:Dataset{T,L}} nout = length(datasets) out = [P[] for _ in 1:nout] guess_lists = _make_vector_vector(guesses, nout) @@ -799,7 +825,9 @@ function parse_guesses( dataset = datasets[j] for g in guess_lists[j] ex = _parse_guess_expression(T, g, dataset, options) - member = PopMember(dataset, ex, options; deterministic=options.deterministic) + member = constructorof(P)( + dataset, ex, options; deterministic=options.deterministic + ) if options.should_optimize_constants member, _ = optimize_constants(dataset, member, options) end @@ -817,6 +845,7 @@ function parse_guesses( end return out end + function _make_vector_vector(guesses, nout) if nout == 1 if guesses isa 
AbstractVector{<:AbstractVector} diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index a935a978c..4e9443f23 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -297,7 +297,9 @@ using .MutationFunctionsModule: using .InterfaceDynamicExpressionsModule: @extend_operators, require_copy_to_workers, make_example_inputs using .LossFunctionsModule: eval_loss, eval_cost, update_baseline_loss!, score_func -using .PopMemberModule: PopMember, reset_birth! +using .PopMemberModule: + AbstractPopMember, PopMember, reset_birth!, popmember_type, expression_type +using .CoreModule.UtilsModule: get_birth_order using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve @@ -338,7 +340,8 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame!, parse_guesses, - logging_callback! + logging_callback!, + infer_popmember_type using .LoggingModule: AbstractSRLogger, SRLogger, get_logger using .TemplateExpressionModule: TemplateExpression, TemplateStructure, TemplateExpressionSpec, ParamVector, has_params @@ -630,11 +633,10 @@ end @recorder record["options"] = "$(options)" nout = length(datasets) - example_dataset = first(datasets) - example_ex = create_expression(init_value(T), options, example_dataset) - NT = typeof(example_ex) - PopType = Population{T,L,NT} - HallOfFameType = HallOfFame{T,L,NT} + PMType = infer_popmember_type(T, L, D, options) + NT = expression_type(PMType) + PopType = Population{T,L,NT,PMType} + HallOfFameType = HallOfFame{T,L,NT,PMType} WorkerOutputType = get_worker_output_type( Val(ropt.parallelism), PopType, HallOfFameType ) @@ -692,9 +694,9 @@ end j in 1:nout ] - seed_members = [PopMember{T,L,NT}[] for j in 1:nout] + seed_members = [Vector{PMType}() for j in 1:nout] - return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; + return SearchState{T,L,NT,PMType,WorkerOutputType,ChannelType}(; procs=procs, we_created_procs=we_created_procs, worker_output=worker_output, @@ -810,10 +812,14 @@ function _preserve_loaded_state!( options::AbstractOptions, ) where {T,L,N} nout = length(state.worker_output) + # Get the prototype to extract types + prototype_pop = state.last_pops[1][1] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + for j in 1:nout, i in 1:(options.populations) - (pop, _, _, _) = extract_from_worker( - state.worker_output[j][i], Population{T,L,N}, HallOfFame{T,L,N} - ) + (pop, _, _, _) = extract_from_worker(state.worker_output[j][i], PopType, HallType) state.last_pops[j][i] = copy(pop) end return nothing @@ -843,11 +849,16 @@ function _warmup_search!( # Multi-threaded doesn't like to fetch within a new task: c_rss = deepcopy(running_search_statistics) last_pop = state.worker_output[j][i] + + # Get the prototype to extract types + prototype_pop = state.last_pops[j][i] + PopType = typeof(prototype_pop) + PM = popmember_type(PopType) + HallType = HallOfFame{T,L,N,PM} + updated_pop = @sr_spawner( begin - in_pop = first( - extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) - ) + in_pop = first(extract_from_worker(last_pop, PopType, HallType)) _dispatch_s_r_cycle( in_pop, dataset, diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 201c047e7..7562208c6 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -52,7 +52,7 @@ using ..CheckConstraintsModule: CheckConstraintsModule as CC using 
..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM -using ..PopMemberModule: PopMember +using ..PopMemberModule: PopMember, AbstractPopMember using ..ComposableExpressionModule: ComposableExpression, ValidVector struct ParamVector{T} <: AbstractVector{T} @@ -745,7 +745,7 @@ function MM.condition_mutation_weights!( @nospecialize(options::AbstractOptions), curmaxsize::Int, nfeatures::Int, -) where {T,L,N<:TemplateExpression,P<:PopMember{T,L,N}} +) where {T,L,N<:TemplateExpression,P<:AbstractPopMember{T,L,N}} if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 @@ -828,7 +828,7 @@ end function MM.condition_mutate_constant!( ::Type{<:TemplateExpression}, weights::AbstractMutationWeights, - member::PopMember, + member::AbstractPopMember, options::AbstractOptions, curmaxsize::Int, ) diff --git a/test/runtests.jl b/test/runtests.jl index 7aef02fea..61e2362b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,6 +87,8 @@ include("test_options.jl") include("test_hash.jl") end +include("test_hof_rows.jl") + @testitem "Test migration" tags = [:part3] begin include("test_migration.jl") end @@ -164,6 +166,7 @@ end end include("test_abstract_numbers.jl") +include("test_abstract_popmember.jl") include("test_logging.jl") include("test_pretty_printing.jl") diff --git a/test/test_abstract_popmember.jl b/test/test_abstract_popmember.jl new file mode 100644 index 000000000..eabc6886d --- /dev/null +++ b/test/test_abstract_popmember.jl @@ -0,0 +1,192 @@ +@testitem "Custom AbstractPopMember implementation" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + using DispatchDoctor: @unstable + + import SymbolicRegression.PopMemberModule: create_child + + # Define a custom PopMember that tracks generation count + mutable struct CustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + tree::N + cost::L + loss::L + birth::Int + complexity::Int + ref::Int + parent::Int + generation::Int # Custom field to track generation + end + + # # Direct constructor that matches field order + function CustomPopMember( + tree::N, + cost::L, + loss::L, + birth::Int, + complexity::Int, + ref::Int, + parent::Int, + generation::Int, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember{T,L,N}( + tree, cost, loss, birth, complexity, ref, parent, generation + ) + end + + function CustomPopMember( + tree::N, + cost::L, + loss::L, + options, + complexity::Int; + parent=-1, + deterministic=nothing, + ) where {T,L,N<:AbstractExpression{T}} + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + # Constructor for Population initialization (dataset, tree, options) + function CustomPopMember( + dataset::SymbolicRegression.Dataset, tree, options; parent=-1, deterministic=nothing + ) + ex = SymbolicRegression.create_expression(tree, options, dataset) + complexity = SymbolicRegression.compute_complexity(ex, options) + cost, loss = SymbolicRegression.eval_cost( + dataset, ex, options; complexity=complexity + ) + + return CustomPopMember( + ex, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + 0, # Initial generation + ) + end + + @unstable DynamicExpressions.constructorof(::Type{<:CustomPopMember}) = CustomPopMember + + # 
Define with_type_parameters for CustomPopMember + @unstable function DynamicExpressions.with_type_parameters( + ::Type{<:CustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N} + ) where {T,L,N} + return CustomPopMember{T,L,N} + end + + # Define copy for CustomPopMember + function Base.copy(p::CustomPopMember) + return CustomPopMember( + copy(p.tree), + copy(p.cost), + copy(p.loss), + copy(p.birth), + copy(getfield(p, :complexity)), + copy(p.ref), + copy(p.parent), + copy(p.generation), + ) + end + + function create_child( + parent::CustomPopMember{T,L}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + ) where {T,L} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + actual_complexity, + abs(rand(Int)), + parent_ref, + parent.generation + 1, + ) + end + + function create_child( + parents::Tuple{<:CustomPopMember,<:CustomPopMember}, + tree::N, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + ) where {T,L,N<:AbstractExpression{T}} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + max_generation = max(parents[1].generation, parents[2].generation) + return CustomPopMember( + tree, + cost, + loss, + SymbolicRegression.CoreModule.UtilsModule.get_birth_order(; + deterministic=options.deterministic + ), + actual_complexity, + abs(rand(Int)), + parent_ref, + max_generation + 1, + ) + end + + # Test that we can run equation_search with CustomPopMember + X = randn(Float32, 2, 100) + y = @. X[1, :]^2 - X[2, :] + + options = SymbolicRegression.Options(; + binary_operators=[+, -], + populations=1, + population_size=20, + maxsize=5, + popmember_type=CustomPopMember, + deterministic=true, + seed=0, + ) + + # Test that options were created with correct type + @test options.popmember_type == CustomPopMember + + hall_of_fame = equation_search( + X, y; options=options, niterations=2, parallelism=:serial + ) + + # Verify that we got results + @test sum(hall_of_fame.exists) > 0 + + # Verify that the members are CustomPopMember + for i in eachindex(hall_of_fame.members, hall_of_fame.exists) + if hall_of_fame.exists[i] + @test hall_of_fame.members[i] isa CustomPopMember + # Check that generation field exists + @test hall_of_fame.members[i].generation >= 0 + end + end + + # Verify we can extract the best member + best_idx = findlast(hall_of_fame.exists) + @test !isnothing(best_idx) + best_member = hall_of_fame.members[best_idx] + @test best_member isa CustomPopMember +end diff --git a/test/test_hof_rows.jl b/test/test_hof_rows.jl new file mode 100644 index 000000000..5a038b31c --- /dev/null +++ b/test/test_hof_rows.jl @@ -0,0 +1,572 @@ +@testitem "HOF rows functionality" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + + # Create test data + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], + unary_operators=[], + maxsize=5, + populations=1, + population_size=10, + tournament_selection_n=3, + deterministic=true, + seed=0, + ) + + dataset = Dataset(X, y) + + @testset "compute_scores" begin + # Create a simple HOF with multiple members + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + + # Add multiple members with different complexities + for i 
in 1:3 + hof.exists[i] = true + end + + members = [hof.members[i] for i in 1:3 if hof.exists[i]] + + # Test score computation + scores = SymbolicRegression.HallOfFameModule.compute_scores(members, options) + + @test length(scores) == length(members) + @test scores[1] == 0 # First member always has score 0 + @test all(s >= 0 for s in scores) # Scores should be non-negative + + # Test with empty members + empty_scores = SymbolicRegression.HallOfFameModule.compute_scores( + typeof(members[1])[], options + ) + @test isempty(empty_scores) + end + + @testset "HOFRows iteration" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + hof.exists[2] = true + + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false, include_score=true + ) + + # Test Base.length + @test length(rows) == 2 + + # Test Base.eltype + @test eltype(rows) == NamedTuple + + # Test iteration + collected = collect(rows) + @test length(collected) == 2 + @test all(r isa NamedTuple for r in collected) + + # Test that scores are included by default for pareto_only=true + @test all(haskey(r, :score) for r in collected) + + # Test equation inclusion + @test all(haskey(r, :equation) for r in collected) + end + + @testset "hof_rows options" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + for i in 1:3 + hof.exists[i] = true + end + + # Test pareto_only=false + rows_all = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false + ) + # Should include all existing members (Pareto might filter some) + @test length(rows_all) == 3 + + # Test include_score=false + rows_no_score = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; include_score=false + ) + for row in rows_no_score + @test !haskey(row, :score) + end + end + + @testset "Empty HOF" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + # Don't mark any as existing + + rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options) + + @test length(rows) == 0 + @test isempty(collect(rows)) + end + + @testset "Backwards compatibility" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + hof.exists[2] = true + + # Test that format_hall_of_fame still works + formatted = SymbolicRegression.HallOfFameModule.format_hall_of_fame(hof, options) + + @test haskey(formatted, :trees) + @test haskey(formatted, :scores) + @test haskey(formatted, :losses) + @test haskey(formatted, :complexities) + @test length(formatted.trees) == length(formatted.scores) + @test length(formatted.trees) == length(formatted.losses) + @test length(formatted.trees) == length(formatted.complexities) + + # Test that string_dominating_pareto_curve still works + curve_string = SymbolicRegression.HallOfFameModule.string_dominating_pareto_curve( + hof, dataset, options + ) + + @test curve_string isa AbstractString + @test contains(curve_string, "Complexity") + @test contains(curve_string, "Loss") + end +end + +@testitem "HOF rows with custom PopMember" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions + using Test + using DispatchDoctor: @unstable + + import SymbolicRegression.PopMemberModule: create_child + import SymbolicRegression.HallOfFameModule: member_to_row + + # Define a custom PopMember with an extra field + mutable struct TestCustomPopMember{T,L,N} <: SymbolicRegression.AbstractPopMember{T,L,N} + tree::N + cost::L + loss::L + 
birth::Int + complexity::Int + ref::Int + parent::Int + custom_field::Float64 # Extra field + end + + # Constructor + function TestCustomPopMember( + tree::N, + cost::L, + loss::L, + birth::Int, + complexity::Int, + ref::Int, + parent::Int, + custom_field::Float64, + ) where {T,L,N<:AbstractExpression{T}} + return TestCustomPopMember{T,L,N}( + tree, cost, loss, birth, complexity, ref, parent, custom_field + ) + end + + function TestCustomPopMember( + tree::N, + cost::L, + loss::L, + options, + complexity::Int; + parent=-1, + deterministic=nothing, + custom_field=1.0, + ) where {T,L,N<:AbstractExpression{T}} + return TestCustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + custom_field, + ) + end + + function TestCustomPopMember( + dataset::SymbolicRegression.Dataset, + tree, + options; + parent=-1, + deterministic=nothing, + custom_field=1.0, + ) + ex = SymbolicRegression.create_expression(tree, options, dataset) + complexity = SymbolicRegression.compute_complexity(ex, options) + cost, loss = SymbolicRegression.eval_cost(dataset, ex, options; complexity) + + return TestCustomPopMember( + ex, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=deterministic), + complexity, + abs(rand(Int)), + parent, + custom_field, + ) + end + + @unstable DynamicExpressions.constructorof(::Type{<:TestCustomPopMember}) = + TestCustomPopMember + + @unstable function DynamicExpressions.with_type_parameters( + ::Type{<:TestCustomPopMember}, ::Type{T}, ::Type{L}, ::Type{N} + ) where {T,L,N} + return TestCustomPopMember{T,L,N} + end + + function Base.copy(p::TestCustomPopMember) + return TestCustomPopMember( + copy(p.tree), + copy(p.cost), + copy(p.loss), + copy(p.birth), + copy(getfield(p, :complexity)), + copy(p.ref), + copy(p.parent), + copy(p.custom_field), + ) + end + + function create_child( + parent::TestCustomPopMember{T,L}, + tree::AbstractExpression{T}, + cost::L, + loss::L, + options; + complexity::Union{Int,Nothing}=nothing, + parent_ref, + ) where {T,L} + actual_complexity = @something complexity SymbolicRegression.compute_complexity( + tree, options + ) + return TestCustomPopMember( + tree, + cost, + loss, + SymbolicRegression.get_birth_order(; deterministic=options.deterministic), + actual_complexity, + abs(rand(Int)), + parent_ref, + parent.custom_field * 1.1, # Modify custom field + ) + end + + # Extend member_to_row for custom PopMember + function member_to_row( + member::TestCustomPopMember, + dataset::SymbolicRegression.Dataset, + options::SymbolicRegression.AbstractOptions; + kwargs..., + ) + base = invoke( + member_to_row, + Tuple{ + SymbolicRegression.AbstractPopMember, + SymbolicRegression.Dataset, + SymbolicRegression.AbstractOptions, + }, + member, + dataset, + options; + kwargs..., + ) + return merge(base, (custom_field=member.custom_field,)) + end + + @testset "Custom PopMember with member_to_row extension" begin + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], + maxsize=5, + popmember_type=TestCustomPopMember, + deterministic=true, + seed=0, + ) + + dataset = Dataset(X, y) + + # Create a custom member + tree = SymbolicRegression.create_expression(1.0f0, options, dataset) + custom_member = TestCustomPopMember( + dataset, tree, options; deterministic=true, custom_field=42.0 + ) + + # Test that member_to_row includes custom field + row = member_to_row(custom_member, dataset, options) + + @test 
haskey(row, :custom_field) + @test row.custom_field == 42.0 + @test haskey(row, :complexity) + @test haskey(row, :loss) + @test haskey(row, :equation) + + # Test with HOF + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.members[1] = custom_member + hof.exists[1] = true + + rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options) + collected = collect(rows) + + @test length(collected) == 1 + @test haskey(collected[1], :custom_field) + @test collected[1].custom_field == 42.0 + end +end + +@testitem "Tables.jl extension" tags = [:part1] begin + using SymbolicRegression + using Test + + # Only run if Tables.jl is available + if isdefined(Base, :get_extension) + # Try to load Tables + try + @eval using Tables + + @testset "Tables.jl integration" begin + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], maxsize=5, deterministic=true, seed=0 + ) + + dataset = Dataset(X, y) + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options) + + # Test Tables.jl interface + @test Tables.istable(rows) + @test Tables.rowaccess(rows) + @test Tables.rows(rows) === rows # Should return itself + + # Test that it works with Tables.columntable + ct = Tables.columntable(rows) + @test ct isa NamedTuple + @test haskey(ct, :complexity) + @test haskey(ct, :loss) + end + catch e + @info "Skipping Tables.jl tests (Tables.jl not available): $e" + end + else + @info "Skipping Tables.jl tests (Julia version < 1.9)" + end +end + +@testitem "Column specifications" tags = [:part1] begin + using SymbolicRegression + using Test + using Printf: @sprintf + + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; binary_operators=[+, -], maxsize=5, deterministic=true, seed=0) + + dataset = Dataset(X, y) + + @testset "HOFColumn basics" begin + # Create a simple column + col = SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "Loss", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ) + + @test col.name == :loss + @test col.header == "Loss" + @test col.width == 8 + @test col.alignment == :right + + # Test getter and formatter + test_row = (loss=0.123456, complexity=5) + @test col.getter(test_row) == 0.123456 + @test col.formatter(0.123456) == "1.23e-01" + end + + @testset "default_columns" begin + # Test default columns without score (linear loss scale) + options_linear = Options(; + binary_operators=[+, -], maxsize=5, loss_scale=:linear, deterministic=true + ) + cols_linear = SymbolicRegression.HallOfFameModule.default_columns(options_linear) + + @test length(cols_linear) == 3 # complexity, loss, equation + @test cols_linear[1].name == :complexity + @test cols_linear[2].name == :loss + @test cols_linear[3].name == :equation + + # Test default columns with score (log loss scale) + options_log = Options(; + binary_operators=[+, -], maxsize=5, loss_scale=:log, deterministic=true + ) + cols_log = SymbolicRegression.HallOfFameModule.default_columns(options_log) + + @test length(cols_log) == 4 # complexity, loss, score, equation + @test cols_log[1].name == :complexity + @test cols_log[2].name == :loss + @test cols_log[3].name == :score + @test cols_log[4].name == :equation + end + + @testset "Custom columns with HOFRows" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + hof.exists[2] = true + + # Create custom 
column specs + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "L", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ), + ] + + # Get rows with custom columns + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false, columns=custom_cols + ) + + # Collect and verify + collected = collect(rows) + @test length(collected) == 2 + + # Should only have the two specified columns + for row in collected + @test haskey(row, :complexity) + @test haskey(row, :loss) + @test !haskey(row, :equation) # Not requested + @test !haskey(row, :cost) # Not requested + end + end + + @testset "string_dominating_pareto_curve with custom columns" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + # Test with default columns + str_default = SymbolicRegression.HallOfFameModule.string_dominating_pareto_curve( + hof, dataset, options + ) + @test str_default isa AbstractString + @test contains(str_default, "Complexity") + @test contains(str_default, "Loss") + + # Test with custom columns + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "L", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :equation, "Eq", row -> row.equation, identity, nothing, :left + ), + ] + + str_custom = SymbolicRegression.HallOfFameModule.string_dominating_pareto_curve( + hof, dataset, options; columns=custom_cols + ) + @test str_custom isa AbstractString + @test contains(str_custom, "C") # Custom header + @test contains(str_custom, "L") # Custom header + @test !contains(str_custom, "Complexity") # Original header should not appear + end + + @testset "Computed columns" begin + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + # Create a computed column (e.g., cost/loss ratio) + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :ratio, + "Cost/Loss", + row -> row.cost / row.loss, # Computed from multiple fields + x -> @sprintf("%.2f", x), + 10, + :right, + ), + ] + + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; pareto_only=false, columns=custom_cols + ) + + collected = collect(rows) + @test length(collected) == 1 + @test haskey(collected[1], :ratio) + @test collected[1].ratio isa Number + end +end + +@testitem "Column specs with Tables.jl" tags = [:part1] begin + using SymbolicRegression + using Test + using Printf: @sprintf + + # Only run if Tables.jl is available + if isdefined(Base, :get_extension) + try + @eval using Tables + + X = Float32[1.0 2.0 3.0; 4.0 5.0 6.0] + y = Float32[1.0, 2.0, 3.0] + + options = Options(; + binary_operators=[+, -], maxsize=5, deterministic=true, seed=0 + ) + + dataset = Dataset(X, y) + hof = SymbolicRegression.HallOfFameModule.HallOfFame(options, dataset) + hof.exists[1] = true + + @testset "Tables.jl with custom columns" begin + custom_cols = [ + SymbolicRegression.HallOfFameModule.HOFColumn( + :complexity, "C", row -> row.complexity, string, 5, :right + ), + SymbolicRegression.HallOfFameModule.HOFColumn( + :loss, "L", row -> row.loss, x -> @sprintf("%.2e", x), 8, :right 
+ ), + ] + + rows = SymbolicRegression.HallOfFameModule.hof_rows( + hof, dataset, options; columns=custom_cols + ) + + # Test schema + schema = Tables.schema(rows) + @test schema !== nothing + @test schema.names == (:complexity, :loss) + + # Test columntable + ct = Tables.columntable(rows) + @test haskey(ct, :complexity) + @test haskey(ct, :loss) + @test !haskey(ct, :equation) # Not in custom columns + end + catch e + @info "Skipping Tables.jl column spec tests: $e" + end + end +end
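
A minimal usage sketch of the two features exercised by the new tests above: the `popmember_type` option for plugging in an `AbstractPopMember` subtype, and the `hof_rows` row iterator consumed by the Tables.jl extension. This is an assumption-laden illustration based on the entry points shown in this patch (it reuses names from test_abstract_popmember.jl and test_hof_rows.jl), not a guaranteed public API:

    using SymbolicRegression

    X = randn(Float32, 2, 100)
    y = @. X[1, :]^2 - X[2, :]

    # `popmember_type` defaults to PopMember; a custom AbstractPopMember
    # subtype (e.g. the CustomPopMember defined in test_abstract_popmember.jl)
    # could be passed here instead.
    options = Options(; binary_operators=[+, -], maxsize=5)

    hof = equation_search(X, y; options=options, niterations=2, parallelism=:serial)

    # Iterate the hall of fame as NamedTuple rows; each row carries at least
    # :complexity, :loss, and :equation (plus :score when included).
    dataset = Dataset(X, y)
    rows = SymbolicRegression.HallOfFameModule.hof_rows(hof, dataset, options)
    for row in rows
        @show row.complexity row.loss row.equation
    end

    # With Tables.jl loaded, the extension lets generic table sinks consume `rows`:
    #     using Tables
    #     Tables.columntable(rows)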