Skip to content

Commit 1461a39

Browse files
committed
Remove internal get_constant helper
Fixes bug in `local_add_neg_to_sub` reported in #584
1 parent 70c64a0 commit 1461a39

File tree

2 files changed

+48
-43
lines changed

2 files changed

+48
-43
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
125125
return consts, origconsts, nonconsts
126126

127127

128-
def get_constant(v):
129-
"""
130-
131-
Returns
132-
-------
133-
object
134-
A numeric constant if v is a Constant or, well, a
135-
numeric constant. If v is a plain Variable, returns None.
136-
137-
"""
138-
if isinstance(v, TensorConstant):
139-
return v.unique_value
140-
elif isinstance(v, Variable):
141-
return None
142-
else:
143-
return v
144-
145-
146128
@register_canonicalize
147129
@register_stabilize
148130
@node_rewriter([Dot])
@@ -1021,8 +1003,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10211003
"""
10221004
Find all constants and put them together into a single constant.
10231005
1024-
Finds all constants in orig_num and orig_denum (using
1025-
get_constant) and puts them together into a single
1006+
Finds all constants in orig_num and orig_denum
1007+
and puts them together into a single
10261008
constant. The constant is inserted as the first element of the
10271009
numerator. If the constant is the neutral element, it is
10281010
removed from the numerator.
@@ -1043,17 +1025,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10431025
numct, denumct = [], []
10441026

10451027
for v in orig_num:
1046-
ct = get_constant(v)
1047-
if ct is not None:
1028+
if isinstance(v, TensorConstant) and v.unique_value is not None:
10481029
# We found a constant in the numerator!
10491030
# We add it to numct
1050-
numct.append(ct)
1031+
numct.append(v.unique_value)
10511032
else:
10521033
num.append(v)
10531034
for v in orig_denum:
1054-
ct = get_constant(v)
1055-
if ct is not None:
1056-
denumct.append(ct)
1035+
if isinstance(v, TensorConstant) and v.unique_value is not None:
1036+
denumct.append(v.unique_value)
10571037
else:
10581038
denum.append(v)
10591039

@@ -1077,11 +1057,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10771057

10781058
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
10791059
# In that case we should only have one constant in `ct`.
1080-
assert len(ct) == 1
1081-
first_num_ct = get_constant(orig_num[0])
1082-
if first_num_ct is not None and ct[0].type.values_eq(
1083-
ct[0].data, first_num_ct
1084-
):
1060+
[var_ct] = ct
1061+
1062+
num_ct = None
1063+
if isinstance(var_ct, TensorConstant):
1064+
num_ct = var_ct.unique_value
1065+
1066+
if num_ct is not None and var_ct.type.values_eq(var_ct.data, num_ct):
10851067
# This is an important trick :( if it so happens that:
10861068
# * there's exactly one constant on the numerator and none on
10871069
# the denominator
@@ -1864,9 +1846,12 @@ def local_add_neg_to_sub(fgraph, node):
18641846
return [new_out]
18651847

18661848
# Check if it is a negative constant
1867-
const = get_constant(second)
1868-
if const is not None and const < 0:
1869-
new_out = sub(first, np.abs(const))
1849+
if (
1850+
isinstance(second, TensorConstant)
1851+
and second.unique_value is not None
1852+
and second.unique_value < 0
1853+
):
1854+
new_out = sub(first, np.abs(second.data))
18701855
return [new_out]
18711856

18721857

@@ -1895,7 +1880,12 @@ def local_mul_zero(fgraph, node):
18951880
@register_specialize
18961881
@node_rewriter([true_div])
18971882
def local_div_to_reciprocal(fgraph, node):
1898-
if np.all(get_constant(node.inputs[0]) == 1.0):
1883+
if (
1884+
get_underlying_scalar_constant_value(
1885+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1886+
)
1887+
== 1.0
1888+
):
18991889
out = node.outputs[0]
19001890
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
19011891
# The ones could have forced upcasting
@@ -1916,7 +1906,9 @@ def local_reciprocal_canon(fgraph, node):
19161906
@register_canonicalize
19171907
@node_rewriter([pt_pow])
19181908
def local_pow_canonicalize(fgraph, node):
1919-
cst = get_constant(node.inputs[1])
1909+
cst = get_underlying_scalar_constant_value(
1910+
node.inputs[1], only_process_constants=True, raise_not_constant=False
1911+
)
19201912
if cst == 0:
19211913
return [alloc_like(1, node.outputs[0], fgraph)]
19221914
if cst == 1:
@@ -1947,7 +1939,12 @@ def local_intdiv_by_one(fgraph, node):
19471939
@node_rewriter([int_div, true_div])
19481940
def local_zero_div(fgraph, node):
19491941
"""0 / x -> 0"""
1950-
if get_constant(node.inputs[0]) == 0:
1942+
if (
1943+
get_underlying_scalar_constant_value(
1944+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1945+
)
1946+
== 0
1947+
):
19511948
ret = alloc_like(0, node.outputs[0], fgraph)
19521949
ret.tag.values_eq_approx = values_eq_approx_remove_nan
19531950
return [ret]
@@ -1960,7 +1957,9 @@ def local_pow_specialize(fgraph, node):
19601957
odtype = node.outputs[0].dtype
19611958
xsym = node.inputs[0]
19621959
ysym = node.inputs[1]
1963-
y = get_constant(ysym)
1960+
y = get_underlying_scalar_constant_value(
1961+
ysym, only_process_constants=True, raise_not_constant=False
1962+
)
19641963
if (y is not None) and not broadcasted_by(xsym, ysym):
19651964
rval = None
19661965

@@ -1998,7 +1997,9 @@ def local_pow_to_nested_squaring(fgraph, node):
19981997
odtype = node.outputs[0].dtype
19991998
xsym = node.inputs[0]
20001999
ysym = node.inputs[1]
2001-
y = get_constant(ysym)
2000+
y = get_underlying_scalar_constant_value(
2001+
ysym, only_process_constants=True, raise_not_constant=False
2002+
)
20022003

20032004
# the next line is needed to fix a strange case that I don't
20042005
# know how to make a separate test.
@@ -2081,7 +2082,9 @@ def local_mul_specialize(fgraph, node):
20812082
nb_neg_node += 1
20822083

20832084
# remove special case arguments of 1, -1 or 0
2084-
y = get_constant(inp)
2085+
y = get_underlying_scalar_constant_value(
2086+
inp, raise_not_constant=False, only_process_constants=True
2087+
)
20852088
if y == 1.0:
20862089
nb_cst += 1
20872090
elif y == -1.0:

tests/tensor/rewriting/test_math.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4376,11 +4376,13 @@ def test_local_add_neg_to_sub(first_negative):
43764376
assert np.allclose(f(x_test, y_test), exp)
43774377

43784378

4379-
def test_local_add_neg_to_sub_const():
4379+
@pytest.mark.parametrize("const_left", (True, False))
4380+
def test_local_add_neg_to_sub_const(const_left):
43804381
x = vector("x")
4381-
const = 5.0
4382+
const = np.full((3, 2), 5.0)
4383+
out = -const + x if const_left else x + (-const)
43824384

4383-
f = function([x], x + (-const), mode=Mode("py"))
4385+
f = function([x], out, mode=Mode("py"))
43844386

43854387
nodes = [
43864388
node.op

0 commit comments

Comments
 (0)