Skip to content

Commit e469bad

Browse files
committed
fix: missing extraction of operators in SymbolicUtils convert
1 parent 0d1a94d commit e469bad

File tree

2 files changed

+86
-26
lines changed

2 files changed

+86
-26
lines changed

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
module DynamicExpressionsSymbolicUtilsExt
22

3-
using SymbolicUtils
4-
import DynamicExpressions.NodeModule:
3+
using DynamicExpressions:
4+
AbstractExpression, get_tree, get_operators, get_variable_names, default_node_type
5+
using DynamicExpressions.NodeModule:
56
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
6-
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
7-
import DynamicExpressions.ValueInterfaceModule: is_valid
8-
import DynamicExpressions.UtilsModule: deprecate_varmap
7+
using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
8+
using DynamicExpressions.UtilsModule: deprecate_varmap
9+
10+
using SymbolicUtils
11+
912
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
10-
import DynamicExpressions: AbstractExpression, get_tree, get_operators
13+
import DynamicExpressions.ValueInterfaceModule: is_valid
1114

1215
const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}
1316
const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)
@@ -88,7 +91,7 @@ function split_eq(
8891
args,
8992
operators::AbstractOperatorEnum,
9093
::Type{N}=Node;
91-
variable_names::Union{Array{String,1},Nothing}=nothing,
94+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
9295
# Deprecated:
9396
varMap=nothing,
9497
) where {N<:AbstractExpressionNode}
@@ -103,8 +106,8 @@ function split_eq(
103106
end
104107
return constructorof(N)(;
105108
op=ind,
106-
l=convert(N, args[1], operators; variable_names=variable_names),
107-
r=convert(N, op(args[2:end]...), operators; variable_names=variable_names),
109+
l=convert(N, args[1], operators; variable_names),
110+
r=convert(N, op(args[2:end]...), operators; variable_names),
108111
)
109112
end
110113

@@ -119,17 +122,14 @@ function Base.convert(
119122
::typeof(SymbolicUtils.Symbolic),
120123
tree::Union{AbstractExpression,AbstractExpressionNode},
121124
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
122-
variable_names::Union{Array{String,1},Nothing}=nothing,
125+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
123126
index_functions::Bool=false,
124127
# Deprecated:
125128
varMap=nothing,
126129
)
127130
variable_names = deprecate_varmap(variable_names, varMap, :convert)
128131
return node_to_symbolic(
129-
tree,
130-
get_operators(tree, operators);
131-
variable_names=variable_names,
132-
index_functions=index_functions,
132+
tree, get_operators(tree, operators); variable_names, index_functions
133133
)
134134
end
135135

@@ -143,11 +143,19 @@ function Base.convert(
143143
::Type{N},
144144
expr::SymbolicUtils.Symbolic,
145145
operators::AbstractOperatorEnum;
146-
variable_names::Union{Array{String,1},Nothing}=nothing,
146+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
147147
) where {N<:AbstractExpressionNode}
148148
variable_names = deprecate_varmap(variable_names, nothing, :convert)
149149
if !iscall(expr)
150-
variable_names === nothing && return constructorof(N)(String(expr.name))
150+
if variable_names === nothing
151+
s = String(expr.name)
152+
# Verify it is of the format "x{num}":
153+
@assert(
154+
occursin(r"^x\d+$", s),
155+
"Variable name $s is not of the format x{num}. Please provide the `variable_names` explicitly."
156+
)
157+
return constructorof(N)(s)
158+
end
151159
return constructorof(N)(String(expr.name), variable_names)
152160
end
153161

@@ -160,23 +168,36 @@ function Base.convert(
160168
op = convert_to_function(SymbolicUtils.operation(expr), operators)
161169
args = SymbolicUtils.arguments(expr)
162170

163-
length(args) > 2 &&
164-
return split_eq(op, args, operators, N; variable_names=variable_names)
171+
length(args) > 2 && return split_eq(op, args, operators, N; variable_names)
165172
ind = if length(args) == 2
166173
findoperation(op, operators.binops)
167174
else
168175
findoperation(op, operators.unaops)
169176
end
170177

171178
return constructorof(N)(;
172-
op=ind,
173-
children=map(x -> convert(N, x, operators; variable_names=variable_names), args),
179+
op=ind, children=map(x -> convert(N, x, operators; variable_names), args)
174180
)
175181
end
176182

183+
_node_type(::Type{<:AbstractExpression{T,N}}) where {T,N<:AbstractExpressionNode} = N
184+
_node_type(::Type{E}) where {E<:AbstractExpression} = default_node_type(E)
185+
186+
function Base.convert(
187+
::Type{E},
188+
x::Union{SymbolicUtils.Symbolic,Number},
189+
operators::AbstractOperatorEnum;
190+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
191+
kws...,
192+
) where {E<:AbstractExpression}
193+
N = _node_type(E)
194+
tree = convert(N, x, operators; variable_names)
195+
return constructorof(E)(tree; operators, variable_names, kws...)
196+
end
197+
177198
"""
178199
node_to_symbolic(tree::AbstractExpressionNode, operators::AbstractOperatorEnum;
179-
variable_names::Union{Array{String, 1}, Nothing}=nothing,
200+
variable_names::Union{AbstractVector{<:AbstractString}, Nothing}=nothing,
180201
index_functions::Bool=false)
181202
182203
The interface to SymbolicUtils.jl. Passing a tree to this function
@@ -186,7 +207,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
186207
187208
- `tree::AbstractExpressionNode`: The equation to convert.
188209
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
189-
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
210+
- `variable_names::Union{AbstractVector{<:AbstractString}, Nothing}=nothing`: What variable names to use for
190211
each feature. Default is [x1, x2, x3, ...].
191212
- `index_functions::Bool=false`: Whether to generate special names for the
192213
operators, which then allows one to convert back to a `AbstractExpressionNode` format
@@ -196,7 +217,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
196217
function node_to_symbolic(
197218
tree::AbstractExpressionNode,
198219
operators::AbstractOperatorEnum;
199-
variable_names::Union{Array{String,1},Nothing}=nothing,
220+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
200221
index_functions::Bool=false,
201222
# Deprecated:
202223
varMap=nothing,
@@ -218,16 +239,24 @@ function node_to_symbolic(
218239
return substitute(expr, subs)
219240
end
220241
function node_to_symbolic(
221-
tree::AbstractExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...
242+
tree::AbstractExpression,
243+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
244+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
245+
kws...,
222246
)
223-
return node_to_symbolic(get_tree(tree), get_operators(tree, operators); kws...)
247+
return node_to_symbolic(
248+
get_tree(tree),
249+
get_operators(tree, operators);
250+
variable_names=get_variable_names(tree, variable_names),
251+
kws...,
252+
)
224253
end
225254

226255
function symbolic_to_node(
227256
eqn::SymbolicUtils.Symbolic,
228257
operators::AbstractOperatorEnum,
229258
::Type{N}=Node;
230-
variable_names::Union{Array{String,1},Nothing}=nothing,
259+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
231260
# Deprecated:
232261
varMap=nothing,
233262
) where {N<:AbstractExpressionNode}

test/test_symbolic_utils.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using SymbolicUtils
22
using DynamicExpressions
3+
using DynamicExpressions: get_operators, get_variable_names
34
using Test
45
include("test_params.jl")
56

@@ -28,3 +29,33 @@ eqn = node_to_symbolic(tree, operators; variable_names=["energy"], index_functio
2829

2930
tree2 = symbolic_to_node(eqn, operators; variable_names=["energy"])
3031
@test string_tree(tree, operators) == string_tree(tree2, operators)
32+
33+
# Test variable name conversion with Expression objects
34+
let
35+
ex = parse_expression(
36+
:(sin(x + y));
37+
binary_operators=[+, *, -, /],
38+
unary_operators=[sin],
39+
variable_names=["x", "y"],
40+
)
41+
42+
# Test conversion to symbolic form preserves variable names
43+
eqn = convert(SymbolicUtils.Symbolic, ex)
44+
@test string(eqn) == "sin(x + y)"
45+
46+
# Test with different variable names in the expression
47+
ex2 = parse_expression(
48+
:(sin(alpha + beta));
49+
binary_operators=[+, *, -, /],
50+
unary_operators=[sin],
51+
variable_names=["alpha", "beta"],
52+
)
53+
eqn2 = convert(SymbolicUtils.Symbolic, ex2)
54+
@test string(eqn2) == "sin(alpha + beta)"
55+
eqn2
56+
57+
# Test round trip preserves structure and variable names
58+
operators = OperatorEnum(; unary_operators=(sin,), binary_operators=(+, *, -, /))
59+
ex2_again = convert(Expression, eqn2, operators; variable_names=["alpha", "beta"])
60+
@test ex2 == ex2_again
61+
end

0 commit comments

Comments
 (0)