Skip to content

Commit 8a293e6

Browse files
fix: use term() guards to match master behavior for commutative ops
Adds SymbolicUtils.term() usage to prevent SUv4 from canonicalizing x*x to x^2 during conversion. This makes the backport behaviorally equivalent to upstream/master. - Import term from SymbolicUtils - Use term(op, ...) for * and + in parse_tree_to_eqs - Use term(*, ...) in multiply_powers to build explicit products - Restore test to check x1*x1 stays as explicit multiplication
1 parent eb77e9c commit 8a293e6

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
88
using DynamicExpressions.UtilsModule: deprecate_varmap
99

1010
using SymbolicUtils
11-
using SymbolicUtils: BasicSymbolic, SymReal, iscall, issym, isconst, unwrap_const
11+
using SymbolicUtils: BasicSymbolic, SymReal, iscall, issym, isconst, unwrap_const, term
1212

1313
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
1414
import DynamicExpressions.ValueInterfaceModule: is_valid
@@ -69,6 +69,15 @@ function parse_tree_to_eqs(
6969
# Convert children to symbolic form
7070
sym_children = map(x -> parse_tree_to_eqs(x, operators, index_functions), children)
7171

72+
# SymbolicUtils v4 may canonicalize some commutative operations at construction time
73+
# (e.g. `x*x` -> `x^2`, or reordering `a + b`).
74+
#
75+
# For stable round-trips (and to avoid introducing `^` when it's not in the operator set),
76+
# construct commutative ops as explicit terms.
77+
if op === (*) || op === (+)
78+
return subs_bad(term(op, sym_children...))
79+
end
80+
7281
return subs_bad(op(sym_children...))
7382
end
7483

@@ -320,11 +329,22 @@ function multiply_powers(
320329
if n_val == 1
321330
return l, true
322331
elseif n_val == -1
323-
return 1.0 / l, true
332+
return term(/, 1.0, l), true
324333
elseif n_val > 1
325-
return reduce(*, [l for i in 1:n_val]), true
334+
# IMPORTANT: use `term(*, ...)` to prevent SymbolicUtils from immediately
335+
# canonicalizing `l*l` back into `l^2`.
336+
out = l
337+
for _ in 2:n_val
338+
out = term(*, out, l)
339+
end
340+
return out, true
326341
elseif n_val < -1
327-
return reduce(/, vcat([1], [l for i in 1:abs(n_val)])), true
342+
# Build 1/(l*l*...) using explicit multiplication terms.
343+
denom = l
344+
for _ in 2:abs(n_val)
345+
denom = term(*, denom, l)
346+
end
347+
return term(/, 1.0, denom), true
328348
else
329349
return 1.0, true
330350
end
@@ -340,6 +360,11 @@ function multiply_powers(
340360
r, complete2 = multiply_powers(args[2])
341361
@return_on_false complete2 eqn
342362
@return_on_false is_valid(r) eqn
363+
# SymbolicUtils v4 normalizes `x*x` into `x^2` via the `*` method; preserve
364+
# explicit multiplication terms so we don't introduce `^` during conversion.
365+
if op == *
366+
return term(op, l, r), true
367+
end
343368
return op(l, r), true
344369
else
345370
# return tree_mapreduce(multiply_powers, op, args)
@@ -351,7 +376,8 @@ function multiply_powers(
351376
end
352377
cumulator = out[1][1]
353378
for i in 2:size(out, 1)
354-
cumulator = op(cumulator, out[i][1])
379+
cumulator =
380+
(op == *) ? term(op, cumulator, out[i][1]) : op(cumulator, out[i][1])
355381
@return_on_false is_valid(cumulator) eqn
356382
end
357383
return cumulator, true

test/test_simplification.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,36 +41,37 @@ tree = convert(Node, eqn2, operators)
4141
# Make sure the other node is x1:
4242
@test (!tree.l.constant ? tree.l : tree.r).feature == 1
4343

44-
# SymbolicUtils v4 automatically simplifies x1*x1 to x1^2
45-
# For round-trip to work, we need ^ in the operator set
46-
operators_with_pow = OperatorEnum(; binary_operators=(+, -, /, *, ^))
44+
# Finally, let's try converting a product, and ensure
45+
# that SymbolicUtils does not convert it to a power:
4746
tree = Node("x1") * Node("x1")
48-
eqn = convert(BasicSymbolic, tree, operators_with_pow)
49-
# The symbolic repr will be x1^2 in SymbolicUtils v4
50-
@test occursin("x1", repr(eqn))
51-
# Test converting back (x^2 comes back as x^2 since ^ is in operators):
52-
tree_copy = convert(Node, eqn, operators_with_pow)
53-
# The structure is preserved as a power in v4
54-
@test occursin("x1", repr(tree_copy))
55-
56-
# Let's test a more complex function with supported operators
57-
# (Custom operators are not supported in SymbolicUtils v4+)
58-
operators = OperatorEnum(; binary_operators=(+, *, -, /), unary_operators=(cos, exp, sin))
47+
eqn = convert(BasicSymbolic, tree, operators)
48+
@test repr(eqn) "x1*x1"
49+
# Test converting back:
50+
tree_copy = convert(Node, eqn, operators)
51+
@test repr(tree_copy) "(x1*x1)"
52+
53+
# Let's test a more complex function. In SymbolicUtils v4+, custom operators need
54+
# `index_functions=true` to round-trip.
5955

6056
x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
57+
pow_abs2(x, y) = abs(x)^y
58+
59+
operators = OperatorEnum(;
60+
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
61+
)
6162
@extend_operators operators
6263
tree = (
63-
((x2 + x2) * ((-0.5982493 / (x1 * x2)) / -0.54734415)) + (
64+
((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + (
6465
sin(
65-
cos(
66+
custom_cos(
6667
sin(1.2926733 - 1.6606787) /
6768
sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426),
68-
) * (cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)),
69-
) / (0.14854191 - ((cos(x2) * -1.6047639) - 0.023943262))
69+
) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)),
70+
) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262))
7071
)
7172
)
7273
# Convert to symbolic form
73-
eqn = convert(BasicSymbolic, tree, operators)
74+
eqn = convert(BasicSymbolic, tree, operators; index_functions=true)
7475

7576
tree_copy = convert(Node, eqn, operators)
7677
tree_copy2 = convert(Node, simplify(eqn), operators)

0 commit comments

Comments
 (0)