Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1a8493a
fix: backport SymbolicUtils v4 compat to release-v1
MilesCranmerBot Feb 15, 2026
50179f3
test: avoid string matching for SymbolicUtils roundtrip
MilesCranmerBot Feb 15, 2026
eb77e9c
chore: format test_symbolic_utils.jl
MilesCranmerBot Feb 15, 2026
8a293e6
fix: use term() guards to match master behavior for commutative ops
MilesCranmerBot Feb 16, 2026
5fe645f
ci: add ext/ to workflow paths for fork validation
MilesCranmerBot Feb 16, 2026
4b21bf6
ci: trigger workflow
MilesCranmerBot Feb 16, 2026
f9a3fbb
fix(symbolicutils): handle atoms after multiply_powers
MilesCranmerBot Feb 17, 2026
ed9aeb0
SU v4: unwrap Const scalar containers + regression tests
MilesCranmerBot Feb 17, 2026
c6566ee
style: run JuliaFormatter
MilesCranmerBot Feb 17, 2026
5612a43
style: format (JuliaFormatter v1)
MilesCranmerBot Feb 17, 2026
7b7a00e
chore: drop CI-trigger / revert unrelated formatting
MilesCranmerBot Feb 17, 2026
d113eee
chore: revert unrelated changes (keep SUv4 patch minimal)
MilesCranmerBot Feb 18, 2026
224476b
test: use BasicSymbolic in simplification tests (SU v4)
MilesCranmerBot Feb 18, 2026
435db64
test: fix JET/Aqua and chainrules compatibility under SU v4
MilesCranmerBot Feb 18, 2026
08ad20e
test: ignore OperatorEnumConstruction in JET checks
MilesCranmerBot Feb 18, 2026
ca7ce41
fix: silence JET undef-var in precompile workload
MilesCranmerBot Feb 21, 2026
6e9093d
test: call bad_op via @eval in chainrules NaN regression
MilesCranmerBot Mar 1, 2026
3cd770f
test: inline @eval for bad_op calls in chainrules NaN regression
MilesCranmerBot Mar 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ MacroTools = "0.4, 0.5"
Optim = "0.19, 1"
PrecompileTools = "1"
Reexport = "1"
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
SymbolicUtils = "4"
Zygote = "0.7"
julia = "1.10"

Expand Down
135 changes: 80 additions & 55 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,14 @@ using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
using DynamicExpressions.UtilsModule: deprecate_varmap

using SymbolicUtils
using SymbolicUtils: BasicSymbolic, SymReal, iscall, issym, isconst, unwrap_const

import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
import DynamicExpressions.ValueInterfaceModule: is_valid

const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}
const SYMBOLIC_UTILS_TYPES = Union{<:Number,BasicSymbolic}
const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)

@static if isdefined(SymbolicUtils, :iscall)
iscall(x) = SymbolicUtils.iscall(x)
else
iscall(x) = SymbolicUtils.istree(x)
end

macro return_on_false(flag, retval)
:(
if !$(esc(flag))
Expand All @@ -29,7 +24,7 @@ macro return_on_false(flag, retval)
)
end

function is_valid(x::SymbolicUtils.Symbolic)
function is_valid(x::BasicSymbolic)
return if iscall(x)
all(is_valid.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
else
Expand All @@ -46,40 +41,53 @@ function parse_tree_to_eqs(
if tree.degree == 0
# Return constant if needed
tree.constant && return subs_bad(tree.val)
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
return SymbolicUtils.Sym{SymReal}(Symbol("x$(tree.feature)"); type=Number)
end
# Collect the next children
# TODO: Type instability!
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
# Get the operation
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
# Create an N tuple of Numbers for each argument
dtypes = map(x -> Number, 1:(tree.degree))
#
if !(op ∈ SUPPORTED_OPS) && index_functions
op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...},Number}}(Symbol(op))

# For custom operators, SymbolicUtils can't represent the Julia function directly.
# When `index_functions=true`, represent the operator by its name as an uninterpreted
# SymbolicUtils function symbol so we can round-trip back to a DynamicExpressions node.
if !(op ∈ SUPPORTED_OPS)
if index_functions
dtypes = ntuple(_ -> Number, tree.degree)
op = SymbolicUtils.Sym{SymReal}(
Symbol(op); type=SymbolicUtils.FnType{Tuple{dtypes...},Number,Nothing}
)
else
error(
"Custom operator '$op' is not supported with SymbolicUtils unless " *
"index_functions=true. Supported operators without indexing: $SUPPORTED_OPS",
)
end
end

return subs_bad(
op(map(x -> parse_tree_to_eqs(x, operators, index_functions), children)...)
)
# Convert children to symbolic form
sym_children = map(x -> parse_tree_to_eqs(x, operators, index_functions), children)

return subs_bad(op(sym_children...))
end

# For operators which are indexed, we need to convert them back
# using the string:
function convert_to_function(
x::SymbolicUtils.Sym{SymbolicUtils.FnType{T,Number}}, operators::AbstractOperatorEnum
) where {T<:Tuple}
degree = length(T.types)
if degree == 1
ind = findoperation(x.name, operators.unaops)
return operators.unaops[ind]
elseif degree == 2
ind = findoperation(x.name, operators.binops)
return operators.binops[ind]
else
throw(AssertionError("Function $(String(x.name)) has degree > 2 !"))
function convert_to_function(x::BasicSymbolic, operators::AbstractOperatorEnum)
if issym(x) && SymbolicUtils.symtype(x) <: SymbolicUtils.FnType
signature, _ = SymbolicUtils.fntype_X_Y(SymbolicUtils.symtype(x))
degree = length(signature.parameters)
name = nameof(x)
if degree == 1
ind = findoperation(name, operators.unaops)
return operators.unaops[ind]
elseif degree == 2
ind = findoperation(name, operators.binops)
return operators.binops[ind]
else
throw(AssertionError("Function $(String(name)) has degree > 2 !"))
end
end
return x
end

# For normal operators, simply return the function itself:
Expand All @@ -90,7 +98,7 @@ function split_eq(
op,
args,
operators::AbstractOperatorEnum,
::Type{N}=Node;
(::Type{N})=Node;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
# Deprecated:
varMap=nothing,
Expand Down Expand Up @@ -119,7 +127,7 @@ function findoperation(op, ops)
end

function Base.convert(
::typeof(SymbolicUtils.Symbolic),
::typeof(BasicSymbolic),
tree::Union{AbstractExpression,AbstractExpressionNode},
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
Expand All @@ -141,22 +149,32 @@ end

function Base.convert(
::Type{N},
expr::SymbolicUtils.Symbolic,
expr::BasicSymbolic,
operators::AbstractOperatorEnum;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, nothing, :convert)
if !iscall(expr)
# Handle constants (v4 wraps numbers in Const variant)
if isconst(expr)
return constructorof(N)(; val=DEFAULT_NODE_TYPE(unwrap_const(expr)))
end
# Handle symbols (variables)
if issym(expr)
exprname = nameof(expr)
if variable_names === nothing
s = String(expr.name)
s = String(exprname)
# Verify it is of the format "x{num}":
@assert(
occursin(r"^x\d+$", s),
"Variable name $s is not of the format x{num}. Please provide the `variable_names` explicitly."
)
return constructorof(N)(s)
end
return constructorof(N)(String(expr.name), variable_names)
return constructorof(N)(String(exprname), variable_names)
end
# Handle function calls
if !iscall(expr)
error("Unknown symbolic expression type: $(typeof(expr))")
end

# First, we remove integer powers:
Expand Down Expand Up @@ -185,7 +203,7 @@ _node_type(::Type{E}) where {E<:AbstractExpression} = default_node_type(E)

function Base.convert(
::Type{E},
x::Union{SymbolicUtils.Symbolic,Number},
x::Union{BasicSymbolic,Number},
operators::AbstractOperatorEnum;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
kws...,
Expand All @@ -209,10 +227,9 @@ will generate a symbolic equation in SymbolicUtils.jl format.
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
- `variable_names::Union{AbstractVector{<:AbstractString}, Nothing}=nothing`: What variable names to use for
each feature. Default is [x1, x2, x3, ...].
- `index_functions::Bool=false`: Whether to generate special names for the
operators, which then allows one to convert back to a `AbstractExpressionNode` format
using `symbolic_to_node`.
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
- `index_functions::Bool=false`: Whether to represent custom operators by name as
uninterpreted SymbolicUtils function symbols. This allows round-tripping back to a
`AbstractExpressionNode` using `symbolic_to_node`.
"""
function node_to_symbolic(
tree::AbstractExpressionNode,
Expand All @@ -231,8 +248,8 @@ function node_to_symbolic(
# Create a substitution tuple
subs = Dict(
[
SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) =>
SymbolicUtils.Sym{LiteralReal}(Symbol(variable_names[i])) for
SymbolicUtils.Sym{SymReal}(Symbol("x$(i)"); type=Number) =>
SymbolicUtils.Sym{SymReal}(Symbol(variable_names[i]); type=Number) for
i in 1:length(variable_names)
]...,
)
Expand All @@ -253,9 +270,9 @@ function node_to_symbolic(
end

function symbolic_to_node(
eqn::SymbolicUtils.Symbolic,
eqn::BasicSymbolic,
operators::AbstractOperatorEnum,
::Type{N}=Node;
(::Type{N})=Node;
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
# Deprecated:
varMap=nothing,
Expand All @@ -268,7 +285,7 @@ function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
return eqn, true
end

function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
function multiply_powers(eqn::BasicSymbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
if !iscall(eqn)
return eqn, true
end
Expand All @@ -277,7 +294,7 @@ function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPE
end

function multiply_powers(
eqn::SymbolicUtils.Symbolic, op::F
eqn::BasicSymbolic, op::F
)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F}
args = SymbolicUtils.arguments(eqn)
nargs = length(args)
Expand All @@ -291,15 +308,23 @@ function multiply_powers(
@return_on_false complete eqn
@return_on_false is_valid(l) eqn
n = args[2]
if typeof(n) <: Integer
if n == 1
# In SymbolicUtils v4, integer constants are wrapped in Const
n_val = if isconst(n)
unwrap_const(n)
elseif typeof(n) <: Integer
n
else
nothing
end
if n_val !== nothing && typeof(n_val) <: Integer
if n_val == 1
return l, true
elseif n == -1
elseif n_val == -1
return 1.0 / l, true
elseif n > 1
return reduce(*, [l for i in 1:n]), true
elseif n < -1
return reduce(/, vcat([1], [l for i in 1:abs(n)])), true
elseif n_val > 1
return reduce(*, [l for i in 1:n_val]), true
elseif n_val < -1
return reduce(/, vcat([1], [l for i in 1:abs(n_val)])), true
else
return 1.0, true
end
Expand Down
Loading
Loading