Skip to content

Commit a336fb9

Browse files
committed
Remove internal get_constant helper
Fixes bug in `local_add_neg_to_sub` reported in #584
1 parent 7846b72 commit a336fb9

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126126
return consts, origconsts, nonconsts
127127

128128

129-
def get_constant(v):
130-
"""
131-
132-
Returns
133-
-------
134-
object
135-
A numeric constant if v is a Constant or, well, a
136-
numeric constant. If v is a plain Variable, returns None.
137-
138-
"""
139-
if isinstance(v, TensorConstant):
140-
return v.unique_value
141-
elif isinstance(v, Variable):
142-
return None
143-
else:
144-
return v
145-
146-
147129
@register_canonicalize
148130
@register_stabilize
149131
@node_rewriter([Dot])
@@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
994976
"""
995977
Find all constants and put them together into a single constant.
996978
997-
Finds all constants in orig_num and orig_denum (using
998-
get_constant) and puts them together into a single
979+
Finds all constants in orig_num and orig_denum
980+
and puts them together into a single
999981
constant. The constant is inserted as the first element of the
1000982
numerator. If the constant is the neutral element, it is
1001983
removed from the numerator.
@@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1016998
numct, denumct = [], []
1017999

10181000
for v in orig_num:
1019-
ct = get_constant(v)
1020-
if ct is not None:
1001+
if isinstance(v, TensorConstant) and v.unique_value is not None:
10211002
# We found a constant in the numerator!
10221003
# We add it to numct
1023-
numct.append(ct)
1004+
numct.append(v.unique_value)
10241005
else:
10251006
num.append(v)
10261007
for v in orig_denum:
1027-
ct = get_constant(v)
1028-
if ct is not None:
1029-
denumct.append(ct)
1008+
if isinstance(v, TensorConstant) and v.unique_value is not None:
1009+
denumct.append(v.unique_value)
10301010
else:
10311011
denum.append(v)
10321012

@@ -1050,11 +1030,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10501030

10511031
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
10521032
# In that case we should only have one constant in `ct`.
1053-
assert len(ct) == 1
1054-
first_num_ct = get_constant(orig_num[0])
1055-
if first_num_ct is not None and ct[0].type.values_eq(
1056-
ct[0].data, first_num_ct
1057-
):
1033+
[var_ct] = ct
1034+
1035+
num_ct = None
1036+
if isinstance(var_ct, TensorConstant):
1037+
num_ct = var_ct.unique_value
1038+
1039+
if num_ct is not None and var_ct.type.values_eq(var_ct.data, num_ct):
10581040
# This is an important trick :( if it so happens that:
10591041
# * there's exactly one constant on the numerator and none on
10601042
# the denominator
@@ -1839,9 +1821,12 @@ def local_add_neg_to_sub(fgraph, node):
18391821
return [new_out]
18401822

18411823
# Check if it is a negative constant
1842-
const = get_constant(second)
1843-
if const is not None and const < 0:
1844-
new_out = sub(first, np.abs(const))
1824+
if (
1825+
isinstance(second, TensorConstant)
1826+
and second.unique_value is not None
1827+
and second.unique_value < 0
1828+
):
1829+
new_out = sub(first, np.abs(second.data))
18451830
return [new_out]
18461831

18471832

@@ -1870,7 +1855,12 @@ def local_mul_zero(fgraph, node):
18701855
@register_specialize
18711856
@node_rewriter([true_div])
18721857
def local_div_to_reciprocal(fgraph, node):
1873-
if np.all(get_constant(node.inputs[0]) == 1.0):
1858+
if (
1859+
get_underlying_scalar_constant_value(
1860+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1861+
)
1862+
== 1.0
1863+
):
18741864
out = node.outputs[0]
18751865
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
18761866
# The ones could have forced upcasting
@@ -1891,7 +1881,9 @@ def local_reciprocal_canon(fgraph, node):
18911881
@register_canonicalize
18921882
@node_rewriter([pt_pow])
18931883
def local_pow_canonicalize(fgraph, node):
1894-
cst = get_constant(node.inputs[1])
1884+
cst = get_underlying_scalar_constant_value(
1885+
node.inputs[1], only_process_constants=True, raise_not_constant=False
1886+
)
18951887
if cst == 0:
18961888
return [alloc_like(1, node.outputs[0], fgraph)]
18971889
if cst == 1:
@@ -1922,7 +1914,12 @@ def local_intdiv_by_one(fgraph, node):
19221914
@node_rewriter([int_div, true_div])
19231915
def local_zero_div(fgraph, node):
19241916
"""0 / x -> 0"""
1925-
if get_constant(node.inputs[0]) == 0:
1917+
if (
1918+
get_underlying_scalar_constant_value(
1919+
node.inputs[0], only_process_constants=True, raise_not_constant=False
1920+
)
1921+
== 0
1922+
):
19261923
ret = alloc_like(0, node.outputs[0], fgraph)
19271924
ret.tag.values_eq_approx = values_eq_approx_remove_nan
19281925
return [ret]
@@ -1935,8 +1932,12 @@ def local_pow_specialize(fgraph, node):
19351932
odtype = node.outputs[0].dtype
19361933
xsym = node.inputs[0]
19371934
ysym = node.inputs[1]
1938-
y = get_constant(ysym)
1939-
if (y is not None) and not broadcasted_by(xsym, ysym):
1935+
try:
1936+
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
1937+
except NotScalarConstantError:
1938+
return
1939+
1940+
if not broadcasted_by(xsym, ysym):
19401941
rval = None
19411942

19421943
if np.all(y == 2):
@@ -1970,10 +1971,14 @@ def local_pow_to_nested_squaring(fgraph, node):
19701971
"""
19711972

19721973
# the idea here is that we have pow(x, y)
1974+
xsym, ysym = node.inputs
1975+
1976+
try:
1977+
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
1978+
except NotScalarConstantError:
1979+
return
1980+
19731981
odtype = node.outputs[0].dtype
1974-
xsym = node.inputs[0]
1975-
ysym = node.inputs[1]
1976-
y = get_constant(ysym)
19771982

19781983
# the next line is needed to fix a strange case that I don't
19791984
# know how to make a separate test.
@@ -1989,7 +1994,7 @@ def local_pow_to_nested_squaring(fgraph, node):
19891994
y = y[0]
19901995
except IndexError:
19911996
pass
1992-
if (y is not None) and not broadcasted_by(xsym, ysym):
1997+
if not broadcasted_by(xsym, ysym):
19931998
rval = None
19941999
# 512 is too small for the cpu and too big for some gpu!
19952000
if abs(y) == int(abs(y)) and abs(y) <= 512:
@@ -2056,7 +2061,9 @@ def local_mul_specialize(fgraph, node):
20562061
nb_neg_node += 1
20572062

20582063
# remove special case arguments of 1, -1 or 0
2059-
y = get_constant(inp)
2064+
y = get_underlying_scalar_constant_value(
2065+
inp, only_process_constants=True, raise_not_constant=False
2066+
)
20602067
if y == 1.0:
20612068
nb_cst += 1
20622069
elif y == -1.0:

tests/tensor/rewriting/test_math.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4440,11 +4440,13 @@ def test_local_add_neg_to_sub(first_negative):
44404440
assert np.allclose(f(x_test, y_test), exp)
44414441

44424442

4443-
def test_local_add_neg_to_sub_const():
4443+
@pytest.mark.parametrize("const_left", (True, False))
4444+
def test_local_add_neg_to_sub_const(const_left):
44444445
x = vector("x")
4445-
const = 5.0
4446+
const = np.full((3, 2), 5.0)
4447+
out = -const + x if const_left else x + (-const)
44464448

4447-
f = function([x], x + (-const), mode=Mode("py"))
4449+
f = function([x], out, mode=Mode("py"))
44484450

44494451
nodes = [
44504452
node.op

0 commit comments

Comments
 (0)