@@ -8,19 +8,14 @@ using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
88using DynamicExpressions. UtilsModule: deprecate_varmap
99
1010using SymbolicUtils
11+ using SymbolicUtils: BasicSymbolic, TreeReal, iscall, issym, isconst, unwrap_const
1112
1213import DynamicExpressions. ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
1314import DynamicExpressions. ValueInterfaceModule: is_valid
1415
15- const SYMBOLIC_UTILS_TYPES = Union{<: Number ,SymbolicUtils . Symbolic{ <: Number } }
16+ const SYMBOLIC_UTILS_TYPES = Union{<: Number ,BasicSymbolic }
1617const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, + , - , * , / )
1718
18- @static if isdefined (SymbolicUtils, :iscall )
19- iscall (x) = SymbolicUtils. iscall (x)
20- else
21- iscall (x) = SymbolicUtils. istree (x)
22- end
23-
2419macro return_on_false (flag, retval)
2520 :(
2621 if ! $ (esc (flag))
@@ -29,7 +24,7 @@ macro return_on_false(flag, retval)
2924 )
3025end
3126
32- function is_valid (x:: SymbolicUtils.Symbolic )
27+ function is_valid (x:: BasicSymbolic )
3328 return if iscall (x)
3429 all (is_valid .([SymbolicUtils. operation (x); SymbolicUtils. arguments (x)]))
3530 else
@@ -46,43 +41,27 @@ function parse_tree_to_eqs(
4641 if tree. degree == 0
4742 # Return constant if needed
4843 tree. constant && return subs_bad (tree. val)
49- return SymbolicUtils. Sym {LiteralReal } (Symbol (" x$(tree. feature) " ))
44+ return SymbolicUtils. Sym {TreeReal } (Symbol (" x$(tree. feature) " ); type = Number )
5045 end
5146 # Collect the next children
5247 # TODO : Type instability!
5348 children =
5449 tree. degree == 2 ? (get_child (tree, 1 ), get_child (tree, 2 )) : (get_child (tree, 1 ),)
5550 # Get the operation
5651 op = tree. degree == 2 ? operators. binops[tree. op] : operators. unaops[tree. op]
57- # Create an N tuple of Numbers for each argument
58- dtypes = map (x -> Number, 1 : (tree. degree))
5952 #
6053 if ! (op ∈ SUPPORTED_OPS) && index_functions
61- op = SymbolicUtils. Sym {(SymbolicUtils.FnType){Tuple{dtypes...},Number}} (Symbol (op))
54+ error (
55+ " index_functions=true is not supported with SymbolicUtils v4+. " *
56+ " Custom operator '$op ' cannot be converted to a symbolic function." ,
57+ )
6258 end
6359
6460 return subs_bad (
6561 op (map (x -> parse_tree_to_eqs (x, operators, index_functions), children)... )
6662 )
6763end
6864
69- # For operators which are indexed, we need to convert them back
70- # using the string:
71- function convert_to_function (
72- x:: SymbolicUtils.Sym{SymbolicUtils.FnType{T,Number}} , operators:: AbstractOperatorEnum
73- ) where {T<: Tuple }
74- degree = length (T. types)
75- if degree == 1
76- ind = findoperation (x. name, operators. unaops)
77- return operators. unaops[ind]
78- elseif degree == 2
79- ind = findoperation (x. name, operators. binops)
80- return operators. binops[ind]
81- else
82- throw (AssertionError (" Function $(String (x. name)) has degree > 2 !" ))
83- end
84- end
85-
8665# For normal operators, simply return the function itself:
8766convert_to_function (x, operators:: AbstractOperatorEnum ) = x
8867
@@ -120,7 +99,7 @@ function findoperation(op, ops)
12099end
121100
122101function Base. convert (
123- :: typeof (SymbolicUtils . Symbolic ),
102+ :: typeof (BasicSymbolic ),
124103 tree:: Union{AbstractExpression,AbstractExpressionNode} ,
125104 operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
126105 variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
@@ -142,22 +121,32 @@ end
142121
143122function Base. convert (
144123 :: Type{N} ,
145- expr:: SymbolicUtils.Symbolic ,
124+ expr:: BasicSymbolic ,
146125 operators:: AbstractOperatorEnum ;
147126 variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
148127) where {N<: AbstractExpressionNode }
149128 variable_names = deprecate_varmap (variable_names, nothing , :convert )
150- if ! iscall (expr)
129+ # Handle constants (v4 wraps numbers in Const variant)
130+ if isconst (expr)
131+ return constructorof (N)(; val= DEFAULT_NODE_TYPE (unwrap_const (expr)))
132+ end
133+ # Handle symbols (variables)
134+ if issym (expr)
135+ exprname = nameof (expr)
151136 if variable_names === nothing
152- s = String (expr . name )
137+ s = String (exprname )
153138 # Verify it is of the format "x{num}":
154139 @assert (
155140 occursin (r" ^x\d +$" , s),
156141 " Variable name $s is not of the format x{num}. Please provide the `variable_names` explicitly."
157142 )
158143 return constructorof (N)(s)
159144 end
160- return constructorof (N)(String (expr. name), variable_names)
145+ return constructorof (N)(String (exprname), variable_names)
146+ end
147+ # Handle function calls
148+ if ! iscall (expr)
149+ error (" Unknown symbolic expression type: $(typeof (expr)) " )
161150 end
162151
163152 # First, we remove integer powers:
@@ -190,7 +179,7 @@ _node_type(::Type{E}) where {E<:AbstractExpression} = default_node_type(E)
190179
191180function Base. convert (
192181 :: Type{E} ,
193- x:: Union{SymbolicUtils.Symbolic ,Number} ,
182+ x:: Union{BasicSymbolic ,Number} ,
194183 operators:: AbstractOperatorEnum ;
195184 variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
196185 kws... ,
@@ -217,7 +206,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
217206- `index_functions::Bool=false`: Whether to generate special names for the
218207 operators, which then allows one to convert back to a `AbstractExpressionNode` format
219208 using `symbolic_to_node`.
220- (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84 ).
209+ (CURRENTLY UNAVAILABLE).
221210"""
222211function node_to_symbolic (
223212 tree:: AbstractExpressionNode{T,2} ,
@@ -236,8 +225,8 @@ function node_to_symbolic(
236225 # Create a substitution tuple
237226 subs = Dict (
238227 [
239- SymbolicUtils. Sym {LiteralReal } (Symbol (" x$(i) " )) =>
240- SymbolicUtils. Sym {LiteralReal } (Symbol (variable_names[i])) for
228+ SymbolicUtils. Sym {TreeReal } (Symbol (" x$(i) " ); type = Number ) =>
229+ SymbolicUtils. Sym {TreeReal } (Symbol (variable_names[i]); type = Number ) for
241230 i in 1 : length (variable_names)
242231 ]. .. ,
243232 )
@@ -258,7 +247,7 @@ function node_to_symbolic(
258247end
259248
260249function symbolic_to_node (
261- eqn:: SymbolicUtils.Symbolic ,
250+ eqn:: BasicSymbolic ,
262251 operators:: AbstractOperatorEnum ,
263252 :: Type{N} = Node;
264253 variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
@@ -273,7 +262,7 @@ function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
273262 return eqn, true
274263end
275264
276- function multiply_powers (eqn:: SymbolicUtils.Symbolic ):: Tuple{SYMBOLIC_UTILS_TYPES,Bool}
265+ function multiply_powers (eqn:: BasicSymbolic ):: Tuple{SYMBOLIC_UTILS_TYPES,Bool}
277266 if ! iscall (eqn)
278267 return eqn, true
279268 end
@@ -282,7 +271,7 @@ function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPE
282271end
283272
284273function multiply_powers (
285- eqn:: SymbolicUtils.Symbolic , op:: F
274+ eqn:: BasicSymbolic , op:: F
286275):: Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F}
287276 args = SymbolicUtils. arguments (eqn)
288277 nargs = length (args)
0 commit comments