3 changes: 3 additions & 0 deletions Project.toml
@@ -36,12 +36,14 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [extensions]
 SymbolicRegressionEnzymeExt = "Enzyme"
 SymbolicRegressionJSON3Ext = "JSON3"
 SymbolicRegressionMooncakeExt = "Mooncake"
 SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"
+SymbolicRegressionTablesExt = "Tables"
 
 [compat]
 ADTypes = "^1.4.0"
@@ -74,4 +76,5 @@ StatsBase = "0.33, 0.34"
 StyledStrings = "1"
 SymbolicUtils = "0.19, ^1.0.5, 2, 3"
 TOML = "<0.0.1, 1"
+Tables = "1"
 julia = "1.10"
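
Note (not part of the diff): these three entries wire up a Julia package extension. With Tables listed as a weak dependency and mapped under [extensions], SymbolicRegressionTablesExt is compiled and loaded only once both SymbolicRegression and Tables are in the session. A minimal sketch of verifying this, using only the standard Base.get_extension helper (Julia 1.9+); nothing here is from the PR itself:

using SymbolicRegression   # extension module is not loaded yet
using Tables               # loading the weak dependency triggers the extension

# Returns the extension module once both packages are loaded, else `nothing`:
ext = Base.get_extension(SymbolicRegression, :SymbolicRegressionTablesExt)
@assert ext isa Module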
46 changes: 46 additions & 0 deletions ext/SymbolicRegressionTablesExt.jl
@@ -0,0 +1,46 @@
+module SymbolicRegressionTablesExt
+
+using Tables: Tables
+import SymbolicRegression.HallOfFameModule: HOFRows, member_to_row
+
+# Make HOFRows compatible with the Tables.jl interface.
+# HOFRows is already iterable via Base.iterate, so we just need to declare compatibility.
+Tables.istable(::Type{<:HOFRows}) = true
+Tables.rowaccess(::Type{<:HOFRows}) = true
+Tables.rows(view::HOFRows) = view  # Return itself since it's already iterable
+
+# Provide schema information for better Tables.jl integration
+function Tables.schema(rows::HOFRows)
+    if isempty(rows.members)
+        # Empty table - can't infer schema
+        return nothing
+    end
+
+    # Get column names from either column specs or the first row
+    if rows.columns !== nothing
+        # Use explicit column specs
+        names = Tuple(col.name for col in rows.columns)
+        # We can't reliably infer types without evaluating, so return nothing for types
+        return Tables.Schema(names, nothing)
+    else
+        # Infer from the first row
+        first_row = member_to_row(
+            rows.members[1], rows.dataset, rows.options; pretty=rows.pretty
+        )
+        if rows.include_score
+            # Will add score in iteration
+            names = (keys(first_row)..., :score)
+        else
+            names = keys(first_row)
+        end
+        # Get types from the first row
+        types = if rows.include_score
+            (typeof.(values(first_row))..., Float64)  # Assume Float64 for score
+        else
+            typeof.(values(first_row))
+        end
+        return Tables.Schema(names, types)
+    end
+end
+
+end
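
Note (not part of the diff): a rough sketch of what this interface enables downstream. `hof_rows` is a hypothetical HOFRows value produced elsewhere by the hall-of-fame code; the Tables.jl calls are the standard interface the extension now satisfies:

using Tables

Tables.istable(typeof(hof_rows))   # true — declared by this extension
Tables.schema(hof_rows)            # column names, plus types when inferable

# Any Tables.jl-compatible sink can now consume it, e.g.:
cols = Tables.columntable(hof_rows)    # NamedTuple of column vectors

# Or iterate rows directly, since Tables.rows returns the HOFRows itself:
for row in Tables.rows(hof_rows)
    # each row exposes its columns by name
end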
8 changes: 4 additions & 4 deletions src/ConstantOptimization.jl
@@ -17,7 +17,7 @@ using ..CoreModule:
     AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options, dataset_fraction
 using ..UtilsModule: get_birth_order, PerTaskCache, stable_get!
 using ..LossFunctionsModule: eval_loss, loss_to_cost
-using ..PopMemberModule: PopMember
+using ..PopMemberModule: AbstractPopMember, PopMember
 
 function can_optimize(::AbstractExpression{T}, options) where {T}
     return can_optimize(T, options)
@@ -31,7 +31,7 @@ end
     member::P,
     options::AbstractOptions;
     rng::AbstractRNG=default_rng(),
-)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
+)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,N,P<:AbstractPopMember{T,L,N}}
     can_optimize(member.tree, options) || return (member, 0.0)
     nconst = count_constants_for_optimization(member.tree)
     nconst == 0 && return (member, 0.0)
@@ -63,7 +63,7 @@ count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex)
 
 function _optimize_constants(
     dataset, member::P, options, algorithm, optimizer_options, rng
-)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}}
+)::Tuple{P,Float64} where {T,L,N,P<:AbstractPopMember{T,L,N}}
     tree = member.tree
     x0, refs = get_scalar_constants(tree)
     @assert count_constants_for_optimization(tree) == length(x0)
@@ -76,7 +76,7 @@ function _optimize_constants(
 end
 function _optimize_constants_inner(
     f::F, fg!::G, x0, refs, dataset, member::P, options, algorithm, optimizer_options, rng
-)::Tuple{P,Float64} where {F,G,T,L,P<:PopMember{T,L}}
+)::Tuple{P,Float64} where {F,G,T,L,N,P<:AbstractPopMember{T,L,N}}
     obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing
         f
     else
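
Note (not part of the diff): the widened `where` clauses mean any subtype of AbstractPopMember{T,L,N} can now pass through constant optimization, not just the concrete PopMember. A loose sketch of the kind of type this enables; the field set is an assumption for illustration, not the actual interface contract:

# Hypothetical custom member type (fields chosen for illustration only):
struct LoggedMember{T,L,N} <: AbstractPopMember{T,L,N}
    tree::N                  # expression with leaf type T
    cost::L                  # regularized objective
    loss::L                  # raw loss
    ref::Int                 # identity used for ancestry tracking
    log::Vector{Symbol}      # extra payload a custom subtype might carry
end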
19 changes: 10 additions & 9 deletions src/ExpressionBuilder.jl
@@ -11,7 +11,8 @@ using DynamicExpressions:
 using ..CoreModule: AbstractOptions, Dataset
 using ..HallOfFameModule: HallOfFame
 using ..PopulationModule: Population
-using ..PopMemberModule: PopMember
+using ..PopMemberModule: PopMember, AbstractPopMember, create_child
+using ..ComplexityModule: compute_complexity
 
 import DynamicExpressions: get_operators
 import ..CoreModule: create_expression
@@ -107,16 +108,16 @@ end
     return with_metadata(ex; init_params(options, dataset, ex, Val(true))...)
 end
 function embed_metadata(
-    member::PopMember, options::AbstractOptions, dataset::Dataset{T,L}
-) where {T,L}
-    return PopMember(
+    member::PM, options::AbstractOptions, dataset::Dataset{T,L}
+) where {T,L,N,PM<:AbstractPopMember{T,L,N}}
+    return create_child(
+        member,
         embed_metadata(member.tree, options, dataset),
         member.cost,
         member.loss,
-        nothing;
-        member.ref,
-        member.parent,
-        deterministic=options.deterministic,
+        options;
+        complexity=compute_complexity(member, options),
+        parent_ref=member.ref,
     )
 end
 function embed_metadata(
@@ -135,7 +135,7 @@ end
 end
 function embed_metadata(
     vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L}
-) where {T,L,H<:Union{HallOfFame,Population,PopMember}}
+) where {T,L,H<:Union{HallOfFame,Population,AbstractPopMember}}
     return map(Fix{2}(Fix{3}(embed_metadata, dataset), options), vec)
 end
 end
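
Note (not part of the diff): the `Fix{2}(Fix{3}(embed_metadata, dataset), options)` currying in the last hunk is easy to misread. Assuming `Fix{i}` pins the i-th positional argument (as with Base.Fix in recent Julia), the mapped function is equivalent to a plain closure:

# Schematically the same as the curried form above:
map(m -> embed_metadata(m, options, dataset), vec)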