Skip to content

Commit 82c4fe3

Browse files
committed
feat: add SymbolicUtils v4 support (drop pre-v4)
Update SymbolicUtils extension to support only v4+, removing backward compatibility code for older versions. Key v4 API changes handled: - Use BasicSymbolic instead of Symbolic - Use TreeReal instead of LiteralReal - Use nameof() instead of .name field access - Sym constructor requires type= keyword argument - Handle wrapped constants via isconst/unwrap_const - Type parameter is SymVariant, not Number Breaking: index_functions=true is not supported (was already documented as "CURRENTLY UNAVAILABLE"). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent d4b12df commit 82c4fe3

File tree

2 files changed

+31
-42
lines changed

2 files changed

+31
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ MacroTools = "0.4, 0.5"
3939
Optim = "0.19, 1"
4040
PrecompileTools = "1"
4141
Reexport = "1"
42-
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
42+
SymbolicUtils = "4"
4343
Zygote = "0.7"
4444
julia = "1.10"
4545

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,14 @@ using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
88
using DynamicExpressions.UtilsModule: deprecate_varmap
99

1010
using SymbolicUtils
11+
using SymbolicUtils: BasicSymbolic, TreeReal, iscall, issym, isconst, unwrap_const
1112

1213
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
1314
import DynamicExpressions.ValueInterfaceModule: is_valid
1415

15-
const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}
16+
const SYMBOLIC_UTILS_TYPES = Union{<:Number,BasicSymbolic}
1617
const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)
1718

18-
@static if isdefined(SymbolicUtils, :iscall)
19-
iscall(x) = SymbolicUtils.iscall(x)
20-
else
21-
iscall(x) = SymbolicUtils.istree(x)
22-
end
23-
2419
macro return_on_false(flag, retval)
2520
:(
2621
if !$(esc(flag))
@@ -29,7 +24,7 @@ macro return_on_false(flag, retval)
2924
)
3025
end
3126

32-
function is_valid(x::SymbolicUtils.Symbolic)
27+
function is_valid(x::BasicSymbolic)
3328
return if iscall(x)
3429
all(is_valid.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
3530
else
@@ -46,43 +41,27 @@ function parse_tree_to_eqs(
4641
if tree.degree == 0
4742
# Return constant if needed
4843
tree.constant && return subs_bad(tree.val)
49-
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
44+
return SymbolicUtils.Sym{TreeReal}(Symbol("x$(tree.feature)"); type=Number)
5045
end
5146
# Collect the next children
5247
# TODO: Type instability!
5348
children =
5449
tree.degree == 2 ? (get_child(tree, 1), get_child(tree, 2)) : (get_child(tree, 1),)
5550
# Get the operation
5651
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
57-
# Create an N tuple of Numbers for each argument
58-
dtypes = map(x -> Number, 1:(tree.degree))
5952
#
6053
if !(op SUPPORTED_OPS) && index_functions
61-
op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...},Number}}(Symbol(op))
54+
error(
55+
"index_functions=true is not supported with SymbolicUtils v4+. " *
56+
"Custom operator '$op' cannot be converted to a symbolic function.",
57+
)
6258
end
6359

6460
return subs_bad(
6561
op(map(x -> parse_tree_to_eqs(x, operators, index_functions), children)...)
6662
)
6763
end
6864

69-
# For operators which are indexed, we need to convert them back
70-
# using the string:
71-
function convert_to_function(
72-
x::SymbolicUtils.Sym{SymbolicUtils.FnType{T,Number}}, operators::AbstractOperatorEnum
73-
) where {T<:Tuple}
74-
degree = length(T.types)
75-
if degree == 1
76-
ind = findoperation(x.name, operators.unaops)
77-
return operators.unaops[ind]
78-
elseif degree == 2
79-
ind = findoperation(x.name, operators.binops)
80-
return operators.binops[ind]
81-
else
82-
throw(AssertionError("Function $(String(x.name)) has degree > 2 !"))
83-
end
84-
end
85-
8665
# For normal operators, simply return the function itself:
8766
convert_to_function(x, operators::AbstractOperatorEnum) = x
8867

@@ -120,7 +99,7 @@ function findoperation(op, ops)
12099
end
121100

122101
function Base.convert(
123-
::typeof(SymbolicUtils.Symbolic),
102+
::typeof(BasicSymbolic),
124103
tree::Union{AbstractExpression,AbstractExpressionNode},
125104
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
126105
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
@@ -142,22 +121,32 @@ end
142121

143122
function Base.convert(
144123
::Type{N},
145-
expr::SymbolicUtils.Symbolic,
124+
expr::BasicSymbolic,
146125
operators::AbstractOperatorEnum;
147126
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
148127
) where {N<:AbstractExpressionNode}
149128
variable_names = deprecate_varmap(variable_names, nothing, :convert)
150-
if !iscall(expr)
129+
# Handle constants (v4 wraps numbers in Const variant)
130+
if isconst(expr)
131+
return constructorof(N)(; val=DEFAULT_NODE_TYPE(unwrap_const(expr)))
132+
end
133+
# Handle symbols (variables)
134+
if issym(expr)
135+
exprname = nameof(expr)
151136
if variable_names === nothing
152-
s = String(expr.name)
137+
s = String(exprname)
153138
# Verify it is of the format "x{num}":
154139
@assert(
155140
occursin(r"^x\d+$", s),
156141
"Variable name $s is not of the format x{num}. Please provide the `variable_names` explicitly."
157142
)
158143
return constructorof(N)(s)
159144
end
160-
return constructorof(N)(String(expr.name), variable_names)
145+
return constructorof(N)(String(exprname), variable_names)
146+
end
147+
# Handle function calls
148+
if !iscall(expr)
149+
error("Unknown symbolic expression type: $(typeof(expr))")
161150
end
162151

163152
# First, we remove integer powers:
@@ -190,7 +179,7 @@ _node_type(::Type{E}) where {E<:AbstractExpression} = default_node_type(E)
190179

191180
function Base.convert(
192181
::Type{E},
193-
x::Union{SymbolicUtils.Symbolic,Number},
182+
x::Union{BasicSymbolic,Number},
194183
operators::AbstractOperatorEnum;
195184
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
196185
kws...,
@@ -217,7 +206,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
217206
- `index_functions::Bool=false`: Whether to generate special names for the
218207
operators, which then allows one to convert back to a `AbstractExpressionNode` format
219208
using `symbolic_to_node`.
220-
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
209+
(CURRENTLY UNAVAILABLE).
221210
"""
222211
function node_to_symbolic(
223212
tree::AbstractExpressionNode{T,2},
@@ -236,8 +225,8 @@ function node_to_symbolic(
236225
# Create a substitution tuple
237226
subs = Dict(
238227
[
239-
SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) =>
240-
SymbolicUtils.Sym{LiteralReal}(Symbol(variable_names[i])) for
228+
SymbolicUtils.Sym{TreeReal}(Symbol("x$(i)"); type=Number) =>
229+
SymbolicUtils.Sym{TreeReal}(Symbol(variable_names[i]); type=Number) for
241230
i in 1:length(variable_names)
242231
]...,
243232
)
@@ -258,7 +247,7 @@ function node_to_symbolic(
258247
end
259248

260249
function symbolic_to_node(
261-
eqn::SymbolicUtils.Symbolic,
250+
eqn::BasicSymbolic,
262251
operators::AbstractOperatorEnum,
263252
::Type{N}=Node;
264253
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
@@ -273,7 +262,7 @@ function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
273262
return eqn, true
274263
end
275264

276-
function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
265+
function multiply_powers(eqn::BasicSymbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
277266
if !iscall(eqn)
278267
return eqn, true
279268
end
@@ -282,7 +271,7 @@ function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPE
282271
end
283272

284273
function multiply_powers(
285-
eqn::SymbolicUtils.Symbolic, op::F
274+
eqn::BasicSymbolic, op::F
286275
)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F}
287276
args = SymbolicUtils.arguments(eqn)
288277
nargs = length(args)

0 commit comments

Comments
 (0)