11module DynamicExpressionsSymbolicUtilsExt
22
3- using SymbolicUtils
4- import DynamicExpressions. NodeModule:
3+ using DynamicExpressions:
4+ AbstractExpression, get_tree, get_operators, get_variable_names, default_node_type
5+ using DynamicExpressions. NodeModule:
56 AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
6- import DynamicExpressions. OperatorEnumModule: AbstractOperatorEnum
7- import DynamicExpressions. ValueInterfaceModule: is_valid
8- import DynamicExpressions. UtilsModule: deprecate_varmap
7+ using DynamicExpressions. OperatorEnumModule: AbstractOperatorEnum
8+ using DynamicExpressions. UtilsModule: deprecate_varmap
9+
10+ using SymbolicUtils
11+
912import DynamicExpressions. ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
10- import DynamicExpressions: AbstractExpression, get_tree, get_operators
13+ import DynamicExpressions. ValueInterfaceModule : is_valid
1114
1215const SYMBOLIC_UTILS_TYPES = Union{<: Number ,SymbolicUtils. Symbolic{<: Number }}
1316const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, + , - , * , / )
@@ -88,7 +91,7 @@ function split_eq(
8891 args,
8992 operators:: AbstractOperatorEnum ,
9093 :: Type{N} = Node;
91- variable_names:: Union{Array{String,1 },Nothing} = nothing ,
94+ variable_names:: Union{AbstractVector{<:AbstractString },Nothing} = nothing ,
9295 # Deprecated:
9396 varMap= nothing ,
9497) where {N<: AbstractExpressionNode }
@@ -103,8 +106,8 @@ function split_eq(
103106 end
104107 return constructorof (N)(;
105108 op= ind,
106- l= convert (N, args[1 ], operators; variable_names= variable_names ),
107- r= convert (N, op (args[2 : end ]. .. ), operators; variable_names= variable_names ),
109+ l= convert (N, args[1 ], operators; variable_names),
110+ r= convert (N, op (args[2 : end ]. .. ), operators; variable_names),
108111 )
109112end
110113
@@ -119,17 +122,14 @@ function Base.convert(
119122 :: typeof (SymbolicUtils. Symbolic),
120123 tree:: Union{AbstractExpression,AbstractExpressionNode} ,
121124 operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
122- variable_names:: Union{Array{String,1 },Nothing} = nothing ,
125+ variable_names:: Union{AbstractVector{<:AbstractString },Nothing} = nothing ,
123126 index_functions:: Bool = false ,
124127 # Deprecated:
125128 varMap= nothing ,
126129)
127130 variable_names = deprecate_varmap (variable_names, varMap, :convert )
128131 return node_to_symbolic (
129- tree,
130- get_operators (tree, operators);
131- variable_names= variable_names,
132- index_functions= index_functions,
132+ tree, get_operators (tree, operators); variable_names, index_functions
133133 )
134134end
135135
@@ -143,11 +143,19 @@ function Base.convert(
143143 :: Type{N} ,
144144 expr:: SymbolicUtils.Symbolic ,
145145 operators:: AbstractOperatorEnum ;
146- variable_names:: Union{Array{String,1 },Nothing} = nothing ,
146+ variable_names:: Union{AbstractVector{<:AbstractString },Nothing} = nothing ,
147147) where {N<: AbstractExpressionNode }
148148 variable_names = deprecate_varmap (variable_names, nothing , :convert )
149149 if ! iscall (expr)
150- variable_names === nothing && return constructorof (N)(String (expr. name))
150+ if variable_names === nothing
151+ s = String (expr. name)
152+ # Verify it is of the format "x{num}":
153+ @assert (
154+ occursin (r" ^x\d +$" , s),
155+ " Variable name $s is not of the format x{num}. Please provide the `variable_names` explicitly."
156+ )
157+ return constructorof (N)(s)
158+ end
151159 return constructorof (N)(String (expr. name), variable_names)
152160 end
153161
@@ -160,23 +168,36 @@ function Base.convert(
160168 op = convert_to_function (SymbolicUtils. operation (expr), operators)
161169 args = SymbolicUtils. arguments (expr)
162170
163- length (args) > 2 &&
164- return split_eq (op, args, operators, N; variable_names= variable_names)
171+ length (args) > 2 && return split_eq (op, args, operators, N; variable_names)
165172 ind = if length (args) == 2
166173 findoperation (op, operators. binops)
167174 else
168175 findoperation (op, operators. unaops)
169176 end
170177
171178 return constructorof (N)(;
172- op= ind,
173- children= map (x -> convert (N, x, operators; variable_names= variable_names), args),
179+ op= ind, children= map (x -> convert (N, x, operators; variable_names), args)
174180 )
175181end
176182
183+ _node_type (:: Type{<:AbstractExpression{T,N}} ) where {T,N<: AbstractExpressionNode } = N
184+ _node_type (:: Type{E} ) where {E<: AbstractExpression } = default_node_type (E)
185+
186+ function Base. convert (
187+ :: Type{E} ,
188+ x:: Union{SymbolicUtils.Symbolic,Number} ,
189+ operators:: AbstractOperatorEnum ;
190+ variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
191+ kws... ,
192+ ) where {E<: AbstractExpression }
193+ N = _node_type (E)
194+ tree = convert (N, x, operators; variable_names)
195+ return constructorof (E)(tree; operators, variable_names, kws... )
196+ end
197+
177198"""
178199 node_to_symbolic(tree::AbstractExpressionNode, operators::AbstractOperatorEnum;
179- variable_names::Union{Array{String, 1 }, Nothing}=nothing,
200+ variable_names::Union{AbstractVector{<:AbstractString }, Nothing}=nothing,
180201 index_functions::Bool=false)
181202
182203The interface to SymbolicUtils.jl. Passing a tree to this function
@@ -186,7 +207,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
186207
187208- `tree::AbstractExpressionNode`: The equation to convert.
188209- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
189- - `variable_names::Union{Array{String, 1 }, Nothing}=nothing`: What variable names to use for
210+ - `variable_names::Union{AbstractVector{<:AbstractString }, Nothing}=nothing`: What variable names to use for
190211 each feature. Default is [x1, x2, x3, ...].
191212- `index_functions::Bool=false`: Whether to generate special names for the
192213 operators, which then allows one to convert back to a `AbstractExpressionNode` format
@@ -196,7 +217,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
196217function node_to_symbolic (
197218 tree:: AbstractExpressionNode ,
198219 operators:: AbstractOperatorEnum ;
199- variable_names:: Union{Array{String,1 },Nothing} = nothing ,
220+ variable_names:: Union{AbstractVector{<:AbstractString },Nothing} = nothing ,
200221 index_functions:: Bool = false ,
201222 # Deprecated:
202223 varMap= nothing ,
@@ -218,16 +239,24 @@ function node_to_symbolic(
218239 return substitute (expr, subs)
219240end
220241function node_to_symbolic (
221- tree:: AbstractExpression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing ; kws...
242+ tree:: AbstractExpression ,
243+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
244+ variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
245+ kws... ,
222246)
223- return node_to_symbolic (get_tree (tree), get_operators (tree, operators); kws... )
247+ return node_to_symbolic (
248+ get_tree (tree),
249+ get_operators (tree, operators);
250+ variable_names= get_variable_names (tree, variable_names),
251+ kws... ,
252+ )
224253end
225254
226255function symbolic_to_node (
227256 eqn:: SymbolicUtils.Symbolic ,
228257 operators:: AbstractOperatorEnum ,
229258 :: Type{N} = Node;
230- variable_names:: Union{Array{String,1 },Nothing} = nothing ,
259+ variable_names:: Union{AbstractVector{<:AbstractString },Nothing} = nothing ,
231260 # Deprecated:
232261 varMap= nothing ,
233262) where {N<: AbstractExpressionNode }
0 commit comments