Skip to content

Commit 30f5b3d

Browse files
committed
Remove internal get_constant helper
Fixes bug in `local_add_neg_to_sub` reported in #584
1 parent a283a96 commit 30f5b3d

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126126
return consts, origconsts, nonconsts
127127

128128

129-
def get_constant(v):
130-
"""
131-
132-
Returns
133-
-------
134-
object
135-
A numeric constant if v is a Constant or, well, a
136-
numeric constant. If v is a plain Variable, returns None.
137-
138-
"""
139-
if isinstance(v, TensorConstant):
140-
return v.unique_value
141-
elif isinstance(v, Variable):
142-
return None
143-
else:
144-
return v
145-
146-
147129
@register_canonicalize
148130
@register_stabilize
149131
@node_rewriter([Dot])
@@ -980,8 +962,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
980962
"""
981963
Find all constants and put them together into a single constant.
982964
983-
Finds all constants in orig_num and orig_denum (using
984-
get_constant) and puts them together into a single
965+
Finds all constants in orig_num and orig_denum
966+
and puts them together into a single
985967
constant. The constant is inserted as the first element of the
986968
numerator. If the constant is the neutral element, it is
987969
removed from the numerator.
@@ -1002,17 +984,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1002984
numct, denumct = [], []
1003985

1004986
for v in orig_num:
1005-
ct = get_constant(v)
1006-
if ct is not None:
987+
if isinstance(v, TensorConstant) and v.unique_value is not None:
1007988
# We found a constant in the numerator!
1008989
# We add it to numct
1009-
numct.append(ct)
990+
numct.append(v.unique_value)
1010991
else:
1011992
num.append(v)
1012993
for v in orig_denum:
1013-
ct = get_constant(v)
1014-
if ct is not None:
1015-
denumct.append(ct)
994+
if isinstance(v, TensorConstant) and v.unique_value is not None:
995+
denumct.append(v.unique_value)
1016996
else:
1017997
denum.append(v)
1018998

@@ -1036,11 +1016,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10361016

10371017
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
10381018
# In that case we should only have one constant in `ct`.
1039-
assert len(ct) == 1
1040-
first_num_ct = get_constant(orig_num[0])
1041-
if first_num_ct is not None and ct[0].type.values_eq(
1042-
ct[0].data, first_num_ct
1043-
):
1019+
[var_ct] = ct
1020+
1021+
num_ct = None
1022+
if isinstance(var_ct, TensorConstant):
1023+
num_ct = var_ct.unique_value
1024+
1025+
if num_ct is not None and var_ct.type.values_eq(var_ct.data, num_ct):
10441026
# This is an important trick :( if it so happens that:
10451027
# * there's exactly one constant on the numerator and none on
10461028
# the denominator
@@ -1825,9 +1807,12 @@ def local_add_neg_to_sub(fgraph, node):
18251807
return [new_out]
18261808

18271809
# Check if it is a negative constant
1828-
const = get_constant(second)
1829-
if const is not None and const < 0:
1830-
new_out = sub(first, np.abs(const))
1810+
if (
1811+
isinstance(second, TensorConstant)
1812+
and second.unique_value is not None
1813+
and second.unique_value < 0
1814+
):
1815+
new_out = sub(first, np.abs(second.data))
18311816
return [new_out]
18321817

18331818

@@ -1856,7 +1841,12 @@ def local_mul_zero(fgraph, node):
18561841
@register_specialize
18571842
@node_rewriter([true_div])
18581843
def local_div_to_reciprocal(fgraph, node):
1859-
if np.all(get_constant(node.inputs[0]) == 1.0):
1844+
if (
1845+
get_underlying_scalar_constant_value(
1846+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1847+
)
1848+
== 1.0
1849+
):
18601850
out = node.outputs[0]
18611851
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
18621852
# The ones could have forced upcasting
@@ -1877,7 +1867,9 @@ def local_reciprocal_canon(fgraph, node):
18771867
@register_canonicalize
18781868
@node_rewriter([pt_pow])
18791869
def local_pow_canonicalize(fgraph, node):
1880-
cst = get_constant(node.inputs[1])
1870+
cst = get_underlying_scalar_constant_value(
1871+
node.inputs[1], only_process_constants=True, raise_not_constant=False
1872+
)
18811873
if cst == 0:
18821874
return [alloc_like(1, node.outputs[0], fgraph)]
18831875
if cst == 1:
@@ -1908,7 +1900,12 @@ def local_intdiv_by_one(fgraph, node):
19081900
@node_rewriter([int_div, true_div])
19091901
def local_zero_div(fgraph, node):
19101902
"""0 / x -> 0"""
1911-
if get_constant(node.inputs[0]) == 0:
1903+
if (
1904+
get_underlying_scalar_constant_value(
1905+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1906+
)
1907+
== 0
1908+
):
19121909
ret = alloc_like(0, node.outputs[0], fgraph)
19131910
ret.tag.values_eq_approx = values_eq_approx_remove_nan
19141911
return [ret]
@@ -1921,8 +1918,12 @@ def local_pow_specialize(fgraph, node):
19211918
odtype = node.outputs[0].dtype
19221919
xsym = node.inputs[0]
19231920
ysym = node.inputs[1]
1924-
y = get_constant(ysym)
1925-
if (y is not None) and not broadcasted_by(xsym, ysym):
1921+
try:
1922+
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
1923+
except NotScalarConstantError:
1924+
return
1925+
1926+
if not broadcasted_by(xsym, ysym):
19261927
rval = None
19271928

19281929
if np.all(y == 2):
@@ -1956,10 +1957,14 @@ def local_pow_to_nested_squaring(fgraph, node):
19561957
"""
19571958

19581959
# the idea here is that we have pow(x, y)
1960+
xsym, ysym = node.inputs
1961+
1962+
try:
1963+
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
1964+
except NotScalarConstantError:
1965+
return
1966+
19591967
odtype = node.outputs[0].dtype
1960-
xsym = node.inputs[0]
1961-
ysym = node.inputs[1]
1962-
y = get_constant(ysym)
19631968

19641969
# the next line is needed to fix a strange case that I don't
19651970
# know how to make a separate test.
@@ -1975,7 +1980,7 @@ def local_pow_to_nested_squaring(fgraph, node):
19751980
y = y[0]
19761981
except IndexError:
19771982
pass
1978-
if (y is not None) and not broadcasted_by(xsym, ysym):
1983+
if not broadcasted_by(xsym, ysym):
19791984
rval = None
19801985
# 512 is too small for the cpu and too big for some gpu!
19811986
if abs(y) == int(abs(y)) and abs(y) <= 512:
@@ -2042,7 +2047,9 @@ def local_mul_specialize(fgraph, node):
20422047
nb_neg_node += 1
20432048

20442049
# remove special case arguments of 1, -1 or 0
2045-
y = get_constant(inp)
2050+
y = get_underlying_scalar_constant_value(
2051+
inp, only_process_constants=True, raise_not_constant=False
2052+
)
20462053
if y == 1.0:
20472054
nb_cst += 1
20482055
elif y == -1.0:

tests/tensor/rewriting/test_math.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4440,11 +4440,13 @@ def test_local_add_neg_to_sub(first_negative):
44404440
assert np.allclose(f(x_test, y_test), exp)
44414441

44424442

4443-
def test_local_add_neg_to_sub_const():
4443+
@pytest.mark.parametrize("const_left", (True, False))
4444+
def test_local_add_neg_to_sub_const(const_left):
44444445
x = vector("x")
4445-
const = 5.0
4446+
const = np.full((3, 2), 5.0)
4447+
out = -const + x if const_left else x + (-const)
44464448

4447-
f = function([x], x + (-const), mode=Mode("py"))
4449+
f = function([x], out, mode=Mode("py"))
44484450

44494451
nodes = [
44504452
node.op

0 commit comments

Comments
 (0)