Skip to content

Commit 5e7fd29

Browse files
author
Luca Citi
committed
Address #1497 by changing instances of np.isclose to a function isclose, which uses 10 ULPs by default
1 parent 227a468 commit 5e7fd29

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,7 +2385,7 @@ def local_log1p(fgraph, node):
23852385
log_arg.owner.inputs, only_process_constants=True
23862386
)
23872387
# scalar_inputs are potentially dimshuffled and fill'd scalars
2388-
if scalars and np.allclose(np.sum(scalars), 1):
2388+
if scalars and isclose(np.sum(scalars), 1):
23892389
if nonconsts:
23902390
ninp = variadic_add(*nonconsts)
23912391
if ninp.dtype != log_arg.type.dtype:
@@ -2990,6 +2990,19 @@ def check_input(inputs):
29902990
return [ret]
29912991

29922992

2993+
def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
2994+
"""
2995+
2996+
Returns
2997+
-------
2998+
bool
2999+
True iff x is a constant close to ref (by default 10 ULPs).
3000+
3001+
"""
3002+
atol = atol + num_ulps * np.spacing(x)
3003+
return np.allclose(x, ref, rtol=rtol, atol=atol)
3004+
3005+
29933006
def _skip_mul_1(r):
29943007
if r.owner and r.owner.op == mul:
29953008
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
@@ -3008,7 +3021,7 @@ def _is_1(expr):
30083021
"""
30093022
try:
30103023
v = get_underlying_scalar_constant_value(expr)
3011-
return np.isclose(v, 1)
3024+
return isclose(v, 1)
30123025
except NotScalarConstantError:
30133026
return False
30143027

@@ -3069,7 +3082,7 @@ def is_1pexp(t, only_process_constants=True):
30693082
scal_sum = scalars[0]
30703083
for s in scalars[1:]:
30713084
scal_sum = scal_sum + s
3072-
if np.allclose(scal_sum, 1):
3085+
if isclose(scal_sum, 1):
30733086
return False, maybe_exp.owner.inputs[0]
30743087
return None
30753088

@@ -3169,7 +3182,7 @@ def is_neg(var):
31693182
for idx, mul_input in enumerate(var_node.inputs):
31703183
try:
31713184
constant = get_underlying_scalar_constant_value(mul_input)
3172-
is_minus_1 = np.isclose(constant, -1)
3185+
is_minus_1 = isclose(constant, -1)
31733186
except NotScalarConstantError:
31743187
is_minus_1 = False
31753188
if is_minus_1:
@@ -3577,7 +3590,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
35773590
# scalar_inputs are potentially dimshuffled and fill'd scalars
35783591
if len(nonconsts) == 1:
35793592
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
3580-
if scalars_ and np.allclose(np.sum(scalars_), 1):
3593+
if scalars_ and isclose(np.sum(scalars_), 1):
35813594
out = [
35823595
alloc_like(
35833596
sigmoid(neg(nonconsts[0].owner.inputs[0])),

tests/tensor/rewriting/test_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4299,7 +4299,7 @@ def test_log1msigm_to_softplus(self):
42994299
f(np.random.random((54, 11)).astype(config.floatX))
43004300

43014301
# Test close to 1
4302-
out = log(1.000001 - sigmoid(x))
4302+
out = log(np.nextafter(1.0, 2.0) - sigmoid(x))
43034303
f = pytensor.function([x], out, mode=self.m)
43044304
topo = f.maker.fgraph.toposort()
43054305
assert len(topo) == 2

0 commit comments

Comments
 (0)