Skip to content

Commit cff058c

Browse files
committed
Do not apply local_add_neg_to_sub rewrite if negative variabe is a constant
1 parent 04ce1c6 commit cff058c

File tree

2 files changed

+53
-57
lines changed

2 files changed

+53
-57
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
535535
@register_stabilize
536536
@register_specialize
537537
@register_canonicalize
538-
@node_rewriter([sub])
538+
@node_rewriter([add, sub])
539539
def local_expm1(fgraph, node):
540-
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
541-
in1, in2 = node.inputs
542-
out = node.outputs[0]
540+
"""Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
541+
if len(node.inputs) != 2:
542+
# TODO: handle more than two inputs in add
543+
return None
543544

544-
if (
545-
in1.owner
546-
and isinstance(in1.owner.op, Elemwise)
547-
and isinstance(in1.owner.op.scalar_op, ps.Exp)
548-
and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
549-
):
550-
in11 = in1.owner.inputs[0]
551-
new_out = expm1(in11)
545+
if isinstance(node.op.scalar_op, ps.Sub):
546+
exp_x, other_inp = node.inputs
547+
if not (
548+
exp_x.owner
549+
and isinstance(exp_x.owner.op, Elemwise)
550+
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
551+
and get_underlying_scalar_constant_value(
552+
other_inp, raise_not_constant=False
553+
)
554+
== 1
555+
):
556+
return None
557+
else:
558+
# Try both orders
559+
other_inp, exp_x = node.inputs
560+
for i in range(2):
561+
if i == 1:
562+
other_inp, exp_x = exp_x, other_inp
563+
if (
564+
exp_x.owner
565+
and isinstance(exp_x.owner.op, Elemwise)
566+
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
567+
and get_underlying_scalar_constant_value(
568+
other_inp, raise_not_constant=False
569+
)
570+
== -1
571+
):
572+
break
573+
else: # no break
574+
return None
552575

553-
if new_out.type.broadcastable != out.type.broadcastable:
554-
new_out = broadcast_arrays(in11, in2)[0]
576+
[old_out] = node.outputs
555577

556-
if new_out.dtype != out.dtype:
557-
new_out = cast(new_out, dtype=out.dtype)
578+
[x] = exp_x.owner.inputs
579+
if x.type.broadcastable != old_out.type.broadcastable:
580+
x = broadcast_arrays(x, other_inp)[0]
558581

559-
if not out.type.is_super(new_out.type):
560-
return
561-
return [new_out]
582+
new_out = expm1(x)
583+
584+
if new_out.dtype != old_out.dtype:
585+
new_out = cast(new_out, dtype=old_out.dtype)
586+
587+
if not old_out.type.is_super(new_out.type):
588+
return None
589+
590+
return [new_out]
562591

563592

564593
@register_specialize
@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
18241853
new_out = sub(first, pre_neg)
18251854
return [new_out]
18261855

1827-
# Check if it is a negative constant
1828-
if (
1829-
isinstance(second, TensorConstant)
1830-
and second.unique_value is not None
1831-
and second.unique_value < 0
1832-
):
1833-
new_out = sub(first, np.abs(second.data))
1834-
return [new_out]
1835-
18361856

18371857
@register_canonicalize
18381858
@node_rewriter([mul])
@@ -2606,9 +2626,9 @@ def local_greedy_distributor(fgraph, node):
26062626
register_stabilize(local_one_minus_erfc)
26072627
register_specialize(local_one_minus_erfc)
26082628

2609-
# erfc(-x)-1=>erf(x)
2629+
# -1 + erfc(-x)=>erf(x)
26102630
local_erf_neg_minus_one = PatternNodeRewriter(
2611-
(sub, (erfc, (neg, "x")), 1),
2631+
(add, -1, (erfc, (neg, "x"))),
26122632
(erf, "x"),
26132633
allow_multiple_clients=True,
26142634
name="local_erf_neg_minus_one",

tests/tensor/rewriting/test_math.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3806,14 +3806,9 @@ def test_local_expm1():
38063806
for n in h.maker.fgraph.toposort()
38073807
)
38083808

3809-
# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
3810-
expect_rewrite = config.mode != "FAST_COMPILE"
3811-
assert (
3812-
any(
3813-
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
3814-
for n in r.maker.fgraph.toposort()
3815-
)
3816-
== expect_rewrite
3809+
assert any(
3810+
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
3811+
for n in r.maker.fgraph.toposort()
38173812
)
38183813

38193814

@@ -4440,25 +4435,6 @@ def test_local_add_neg_to_sub(first_negative):
44404435
assert np.allclose(f(x_test, y_test), exp)
44414436

44424437

4443-
@pytest.mark.parametrize("const_left", (True, False))
4444-
def test_local_add_neg_to_sub_const(const_left):
4445-
x = vector("x")
4446-
const = np.full((3, 2), 5.0)
4447-
out = -const + x if const_left else x + (-const)
4448-
4449-
f = function([x], out, mode=Mode("py"))
4450-
4451-
nodes = [
4452-
node.op
4453-
for node in f.maker.fgraph.toposort()
4454-
if not isinstance(node.op, DimShuffle | Alloc)
4455-
]
4456-
assert nodes == [pt.sub]
4457-
4458-
x_test = np.array([3, 4], dtype=config.floatX)
4459-
assert np.allclose(f(x_test), x_test + (-const))
4460-
4461-
44624438
def test_log1mexp_stabilization():
44634439
mode = Mode("py").including("stabilize")
44644440

0 commit comments

Comments
 (0)