Skip to content

Commit 1a8493a

Browse files
fix: backport SymbolicUtils v4 compat to release-v1
Restore index_functions round-trip for custom operators under SymbolicUtils v4.
1 parent 55da8b0 commit 1a8493a

11 files changed

+317
-208
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ MacroTools = "0.4, 0.5"
3737
Optim = "0.19, 1"
3838
PrecompileTools = "1"
3939
Reexport = "1"
40-
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
40+
SymbolicUtils = "4"
4141
Zygote = "0.7"
4242
julia = "1.10"
4343

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 80 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,14 @@ using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
88
using DynamicExpressions.UtilsModule: deprecate_varmap
99

1010
using SymbolicUtils
11+
using SymbolicUtils: BasicSymbolic, SymReal, iscall, issym, isconst, unwrap_const
1112

1213
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
1314
import DynamicExpressions.ValueInterfaceModule: is_valid
1415

15-
const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}
16+
const SYMBOLIC_UTILS_TYPES = Union{<:Number,BasicSymbolic}
1617
const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)
1718

18-
@static if isdefined(SymbolicUtils, :iscall)
19-
iscall(x) = SymbolicUtils.iscall(x)
20-
else
21-
iscall(x) = SymbolicUtils.istree(x)
22-
end
23-
2419
macro return_on_false(flag, retval)
2520
:(
2621
if !$(esc(flag))
@@ -29,7 +24,7 @@ macro return_on_false(flag, retval)
2924
)
3025
end
3126

32-
function is_valid(x::SymbolicUtils.Symbolic)
27+
function is_valid(x::BasicSymbolic)
3328
return if iscall(x)
3429
all(is_valid.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
3530
else
@@ -46,40 +41,53 @@ function parse_tree_to_eqs(
4641
if tree.degree == 0
4742
# Return constant if needed
4843
tree.constant && return subs_bad(tree.val)
49-
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
44+
return SymbolicUtils.Sym{SymReal}(Symbol("x$(tree.feature)"); type=Number)
5045
end
5146
# Collect the next children
5247
# TODO: Type instability!
5348
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
5449
# Get the operation
5550
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
56-
# Create an N tuple of Numbers for each argument
57-
dtypes = map(x -> Number, 1:(tree.degree))
58-
#
59-
if !(op SUPPORTED_OPS) && index_functions
60-
op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...},Number}}(Symbol(op))
51+
52+
# For custom operators, SymbolicUtils can't represent the Julia function directly.
53+
# When `index_functions=true`, represent the operator by its name as an uninterpreted
54+
# SymbolicUtils function symbol so we can round-trip back to a DynamicExpressions node.
55+
if !(op SUPPORTED_OPS)
56+
if index_functions
57+
dtypes = ntuple(_ -> Number, tree.degree)
58+
op = SymbolicUtils.Sym{SymReal}(
59+
Symbol(op); type=SymbolicUtils.FnType{Tuple{dtypes...},Number,Nothing}
60+
)
61+
else
62+
error(
63+
"Custom operator '$op' is not supported with SymbolicUtils unless " *
64+
"index_functions=true. Supported operators without indexing: $SUPPORTED_OPS",
65+
)
66+
end
6167
end
6268

63-
return subs_bad(
64-
op(map(x -> parse_tree_to_eqs(x, operators, index_functions), children)...)
65-
)
69+
# Convert children to symbolic form
70+
sym_children = map(x -> parse_tree_to_eqs(x, operators, index_functions), children)
71+
72+
return subs_bad(op(sym_children...))
6673
end
6774

68-
# For operators which are indexed, we need to convert them back
69-
# using the string:
70-
function convert_to_function(
71-
x::SymbolicUtils.Sym{SymbolicUtils.FnType{T,Number}}, operators::AbstractOperatorEnum
72-
) where {T<:Tuple}
73-
degree = length(T.types)
74-
if degree == 1
75-
ind = findoperation(x.name, operators.unaops)
76-
return operators.unaops[ind]
77-
elseif degree == 2
78-
ind = findoperation(x.name, operators.binops)
79-
return operators.binops[ind]
80-
else
81-
throw(AssertionError("Function $(String(x.name)) has degree > 2 !"))
75+
function convert_to_function(x::BasicSymbolic, operators::AbstractOperatorEnum)
76+
if issym(x) && SymbolicUtils.symtype(x) <: SymbolicUtils.FnType
77+
signature, _ = SymbolicUtils.fntype_X_Y(SymbolicUtils.symtype(x))
78+
degree = length(signature.parameters)
79+
name = nameof(x)
80+
if degree == 1
81+
ind = findoperation(name, operators.unaops)
82+
return operators.unaops[ind]
83+
elseif degree == 2
84+
ind = findoperation(name, operators.binops)
85+
return operators.binops[ind]
86+
else
87+
throw(AssertionError("Function $(String(name)) has degree > 2 !"))
88+
end
8289
end
90+
return x
8391
end
8492

8593
# For normal operators, simply return the function itself:
@@ -90,7 +98,7 @@ function split_eq(
9098
op,
9199
args,
92100
operators::AbstractOperatorEnum,
93-
::Type{N}=Node;
101+
(::Type{N})=Node;
94102
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
95103
# Deprecated:
96104
varMap=nothing,
@@ -119,7 +127,7 @@ function findoperation(op, ops)
119127
end
120128

121129
function Base.convert(
122-
::typeof(SymbolicUtils.Symbolic),
130+
::typeof(BasicSymbolic),
123131
tree::Union{AbstractExpression,AbstractExpressionNode},
124132
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
125133
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
@@ -141,22 +149,32 @@ end
141149

142150
function Base.convert(
143151
::Type{N},
144-
expr::SymbolicUtils.Symbolic,
152+
expr::BasicSymbolic,
145153
operators::AbstractOperatorEnum;
146154
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
147155
) where {N<:AbstractExpressionNode}
148156
variable_names = deprecate_varmap(variable_names, nothing, :convert)
149-
if !iscall(expr)
157+
# Handle constants (v4 wraps numbers in Const variant)
158+
if isconst(expr)
159+
return constructorof(N)(; val=DEFAULT_NODE_TYPE(unwrap_const(expr)))
160+
end
161+
# Handle symbols (variables)
162+
if issym(expr)
163+
exprname = nameof(expr)
150164
if variable_names === nothing
151-
s = String(expr.name)
165+
s = String(exprname)
152166
# Verify it is of the format "x{num}":
153167
@assert(
154168
occursin(r"^x\d+$", s),
155169
"Variable name $s is not of the format x{num}. Please provide the `variable_names` explicitly."
156170
)
157171
return constructorof(N)(s)
158172
end
159-
return constructorof(N)(String(expr.name), variable_names)
173+
return constructorof(N)(String(exprname), variable_names)
174+
end
175+
# Handle function calls
176+
if !iscall(expr)
177+
error("Unknown symbolic expression type: $(typeof(expr))")
160178
end
161179

162180
# First, we remove integer powers:
@@ -185,7 +203,7 @@ _node_type(::Type{E}) where {E<:AbstractExpression} = default_node_type(E)
185203

186204
function Base.convert(
187205
::Type{E},
188-
x::Union{SymbolicUtils.Symbolic,Number},
206+
x::Union{BasicSymbolic,Number},
189207
operators::AbstractOperatorEnum;
190208
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
191209
kws...,
@@ -209,10 +227,9 @@ will generate a symbolic equation in SymbolicUtils.jl format.
209227
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
210228
- `variable_names::Union{AbstractVector{<:AbstractString}, Nothing}=nothing`: What variable names to use for
211229
each feature. Default is [x1, x2, x3, ...].
212-
- `index_functions::Bool=false`: Whether to generate special names for the
213-
operators, which then allows one to convert back to a `AbstractExpressionNode` format
214-
using `symbolic_to_node`.
215-
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
230+
- `index_functions::Bool=false`: Whether to represent custom operators by name as
231+
uninterpreted SymbolicUtils function symbols. This allows round-tripping back to a
232+
`AbstractExpressionNode` using `symbolic_to_node`.
216233
"""
217234
function node_to_symbolic(
218235
tree::AbstractExpressionNode,
@@ -231,8 +248,8 @@ function node_to_symbolic(
231248
# Create a substitution tuple
232249
subs = Dict(
233250
[
234-
SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) =>
235-
SymbolicUtils.Sym{LiteralReal}(Symbol(variable_names[i])) for
251+
SymbolicUtils.Sym{SymReal}(Symbol("x$(i)"); type=Number) =>
252+
SymbolicUtils.Sym{SymReal}(Symbol(variable_names[i]); type=Number) for
236253
i in 1:length(variable_names)
237254
]...,
238255
)
@@ -253,9 +270,9 @@ function node_to_symbolic(
253270
end
254271

255272
function symbolic_to_node(
256-
eqn::SymbolicUtils.Symbolic,
273+
eqn::BasicSymbolic,
257274
operators::AbstractOperatorEnum,
258-
::Type{N}=Node;
275+
(::Type{N})=Node;
259276
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
260277
# Deprecated:
261278
varMap=nothing,
@@ -268,7 +285,7 @@ function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
268285
return eqn, true
269286
end
270287

271-
function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
288+
function multiply_powers(eqn::BasicSymbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
272289
if !iscall(eqn)
273290
return eqn, true
274291
end
@@ -277,7 +294,7 @@ function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPE
277294
end
278295

279296
function multiply_powers(
280-
eqn::SymbolicUtils.Symbolic, op::F
297+
eqn::BasicSymbolic, op::F
281298
)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F}
282299
args = SymbolicUtils.arguments(eqn)
283300
nargs = length(args)
@@ -291,15 +308,23 @@ function multiply_powers(
291308
@return_on_false complete eqn
292309
@return_on_false is_valid(l) eqn
293310
n = args[2]
294-
if typeof(n) <: Integer
295-
if n == 1
311+
# In SymbolicUtils v4, integer constants are wrapped in Const
312+
n_val = if isconst(n)
313+
unwrap_const(n)
314+
elseif typeof(n) <: Integer
315+
n
316+
else
317+
nothing
318+
end
319+
if n_val !== nothing && typeof(n_val) <: Integer
320+
if n_val == 1
296321
return l, true
297-
elseif n == -1
322+
elseif n_val == -1
298323
return 1.0 / l, true
299-
elseif n > 1
300-
return reduce(*, [l for i in 1:n]), true
301-
elseif n < -1
302-
return reduce(/, vcat([1], [l for i in 1:abs(n)])), true
324+
elseif n_val > 1
325+
return reduce(*, [l for i in 1:n_val]), true
326+
elseif n_val < -1
327+
return reduce(/, vcat([1], [l for i in 1:abs(n_val)])), true
303328
else
304329
return 1.0, true
305330
end

0 commit comments

Comments
 (0)