Skip to content

Commit 6c20b96

Browse files
committed
Don't apply local_add_neg_to_sub rewrite if negative variabe is a constant
1 parent f75d89d commit 6c20b96

File tree

2 files changed

+0
-28
lines changed

2 files changed

+0
-28
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,15 +1824,6 @@ def local_add_neg_to_sub(fgraph, node):
18241824
new_out = sub(first, pre_neg)
18251825
return [new_out]
18261826

1827-
# Check if it is a negative constant
1828-
if (
1829-
isinstance(second, TensorConstant)
1830-
and second.unique_value is not None
1831-
and second.unique_value < 0
1832-
):
1833-
new_out = sub(first, np.abs(second.data))
1834-
return [new_out]
1835-
18361827

18371828
@register_canonicalize
18381829
@node_rewriter([mul])

tests/tensor/rewriting/test_math.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4440,25 +4440,6 @@ def test_local_add_neg_to_sub(first_negative):
44404440
assert np.allclose(f(x_test, y_test), exp)
44414441

44424442

4443-
@pytest.mark.parametrize("const_left", (True, False))
4444-
def test_local_add_neg_to_sub_const(const_left):
4445-
x = vector("x")
4446-
const = np.full((3, 2), 5.0)
4447-
out = -const + x if const_left else x + (-const)
4448-
4449-
f = function([x], out, mode=Mode("py"))
4450-
4451-
nodes = [
4452-
node.op
4453-
for node in f.maker.fgraph.toposort()
4454-
if not isinstance(node.op, DimShuffle | Alloc)
4455-
]
4456-
assert nodes == [pt.sub]
4457-
4458-
x_test = np.array([3, 4], dtype=config.floatX)
4459-
assert np.allclose(f(x_test), x_test + (-const))
4460-
4461-
44624443
def test_log1mexp_stabilization():
44634444
mode = Mode("py").including("stabilize")
44644445

0 commit comments

Comments
 (0)