
Commit d4e8f73

lciti (Luca Citi) and ricardoV94 authored
Use stricter numerical tolerance in rewrites and allow casting in PatternNodeRewriter (#1526)
* Implemented `allow_cast` in `PatternNodeRewriter` to allow rewrites that would otherwise fail when the new and old dtypes differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where `x` is an fmatrix) because the type would change. This commit allows an automatic cast to be added, so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added.
* Added test cases for which issue #1497 fails.
* Changed `PatternNodeRewriter.transform` to allow types that do not contain a dtype, like `MyType` in the tests.
* Addressed #1497 by replacing instances of `np.isclose` with a new helper `isclose`, which uses a tolerance of 10 ULPs by default.
* Addressed failing tests (with older Python/NumPy versions).
* Addressed feedback by ricardoV94.
* Added a test that `PatternNodeRewriter` doesn't support multi-output nodes in the pattern (but it's fine if they're just root inputs).

---------

Co-authored-by: Luca Citi <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 0c13849 · commit d4e8f73

6 files changed: +125 −32 lines changed


pytensor/graph/rewriting/basic.py

Lines changed: 21 additions & 10 deletions
@@ -1550,6 +1550,7 @@ def __init__(
         tracks=(),
         get_nodes=None,
         values_eq_approx=None,
+        allow_cast=True,
     ):
         """
@@ -1572,6 +1573,10 @@ def __init__(
            If you provide `tracks`, you must provide this parameter. It must be a
            function that takes the tracked node and returns a list of nodes on
            which we will try this rewrite.
+        values_eq_approx
+            TODO
+        allow_cast
+            Automatically cast the output of the rewrite whenever new and old types differ

         Notes
         -----
@@ -1586,6 +1591,7 @@ def __init__(
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
+        self.allow_cast = allow_cast
         if isinstance(in_pattern, list | tuple):
             self.op = self.in_pattern[0]
         elif isinstance(in_pattern, dict):
@@ -1630,6 +1636,10 @@ def transform(self, fgraph, node, get_nodes=True):
         if node.op != self.op:
             return False

+        if len(node.outputs) != 1:
+            # PatternNodeRewriter doesn't support replacing multi-output nodes
+            return False
+
         s = unify(self.in_pattern, node.out)

         if s is False:
@@ -1652,19 +1662,20 @@ def transform(self, fgraph, node, get_nodes=True):
         ):
             return False

-        if ret.owner:
+        [old_out] = node.outputs
+        if not old_out.type.is_super(ret.type):
+            # Type doesn't match
             if not (
-                len(node.outputs) == len(ret.owner.outputs)
-                and all(
-                    o.type.is_super(new_o.type)
-                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
-                )
+                self.allow_cast
+                and isinstance(old_out.type, pytensor.tensor.TensorType)
+                and isinstance(ret.type, pytensor.tensor.TensorType)
             ):
                 return False
-        else:
-            # ret is just an input variable
-            assert len(node.outputs) == 1
-            if not node.outputs[0].type.is_super(ret.type):
+
+            # Try to cast tensors
+            ret = ret.astype(old_out.type.dtype)
+            if not old_out.type.is_super(ret.type):
+                # Still doesn't match
                 return False

         return [ret]
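To see the new flag in action, here is a minimal runnable sketch. The `x / x -> x` pattern is contrived (it is not a rewrite that ships with pytensor) and exists purely to force a dtype mismatch between the matched output and its replacement:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import PatternNodeRewriter, in2out

# Contrived pattern: replace `x / x` with `x`. True division of an int64
# vector by itself yields a float64 output, so the replacement `x` (int64)
# no longer matches the old output type.
ps = PatternNodeRewriter(
    (pt.true_div, "x", "x"),
    "x",
    allow_cast=True,  # insert an automatic cast when only the dtype differs
)

x = pt.lvector("x")  # int64 vector
fg = FunctionGraph([x], [x / x])
in2out(ps).rewrite(fg)
pytensor.dprint(fg.outputs[0])
# With allow_cast=True the new output is cast(x, "float64");
# with allow_cast=False the dtype mismatch makes the rewrite bail out.
```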

pytensor/tensor/rewriting/math.py

Lines changed: 20 additions & 5 deletions
@@ -2440,7 +2440,7 @@ def local_log1p(fgraph, node):
         log_arg.owner.inputs, only_process_constants=True
     )
     # scalar_inputs are potentially dimshuffled and fill'd scalars
-    if scalars and np.allclose(np.sum(scalars), 1):
+    if scalars and isclose(np.sum(scalars), 1):
         if nonconsts:
             ninp = variadic_add(*nonconsts)
             if ninp.dtype != log_arg.type.dtype:
@@ -3045,6 +3045,21 @@ def check_input(inputs):
         return [ret]


+def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
+    """
+
+    Returns
+    -------
+    bool
+        True iff x is a constant close to ref (by default 10 ULPs).
+
+    """
+    x = np.asarray(x)
+    if np.issubdtype(x.dtype, np.floating):
+        atol = atol + num_ulps * np.abs(np.spacing(x.dtype.type(ref)))
+    return np.allclose(x, ref, rtol=rtol, atol=atol)
+
+
 def _skip_mul_1(r):
     if r.owner and r.owner.op == mul:
         not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
@@ -3063,7 +3078,7 @@ def _is_1(expr):
     """
     try:
         v = get_underlying_scalar_constant_value(expr)
-        return np.isclose(v, 1)
+        return isclose(v, 1)
     except NotScalarConstantError:
         return False

@@ -3124,7 +3139,7 @@ def is_1pexp(t, only_process_constants=True):
             scal_sum = scalars[0]
             for s in scalars[1:]:
                 scal_sum = scal_sum + s
-            if np.allclose(scal_sum, 1):
+            if isclose(scal_sum, 1):
                 return False, maybe_exp.owner.inputs[0]
     return None

@@ -3224,7 +3239,7 @@ def is_neg(var):
     for idx, mul_input in enumerate(var_node.inputs):
         try:
             constant = get_underlying_scalar_constant_value(mul_input)
-            is_minus_1 = np.isclose(constant, -1)
+            is_minus_1 = isclose(constant, -1)
         except NotScalarConstantError:
             is_minus_1 = False
         if is_minus_1:
@@ -3632,7 +3647,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
         # scalar_inputs are potentially dimshuffled and fill'd scalars
         if len(nonconsts) == 1:
             if nonconsts[0].owner and nonconsts[0].owner.op == exp:
-                if scalars_ and np.allclose(np.sum(scalars_), 1):
+                if scalars_ and isclose(np.sum(scalars_), 1):
                     out = [
                         alloc_like(
                             sigmoid(neg(nonconsts[0].owner.inputs[0])),
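For intuition about the new tolerance, the following standalone snippet duplicates the helper's logic outside pytensor and contrasts it with `np.isclose`'s defaults (the constants are chosen to mirror the new test cases):

```python
import numpy as np

def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
    # Same logic as the new helper: widen atol by 10 ULPs of `ref` in x's
    # own floating dtype, and drop np.isclose's loose default tolerances.
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.floating):
        atol = atol + num_ulps * np.abs(np.spacing(x.dtype.type(ref)))
    return np.allclose(x, ref, rtol=rtol, atol=atol)

c = np.float32(1 - 9e-6)      # roughly 75 float32 ULPs below 1
print(np.isclose(c, 1))       # True: default rtol=1e-5 absorbs the gap
print(isclose(c, 1))          # False: well outside the 10-ULP band

d = np.float32(1 - 1e-7)      # rounds to about 2 ULPs below 1
print(isclose(d, 1))          # True: inside the 10-ULP band
```

With `rtol=1e-5`, `np.isclose` accepts constants that are billions of float64 ULPs away from the reference; the ULP-based default removes that looseness.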

tests/graph/rewriting/test_basic.py

Lines changed: 23 additions & 9 deletions
@@ -41,6 +41,7 @@
     op_y,
     op_z,
 )
+from tests.unittest_tools import assert_equal_computations


 class AssertNoChanges(Feature):
@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
     assert e.type.is_super(fg.outputs[0].type)


-def test_patternsub_different_output_lengths():
-    # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
-    ps = PatternNodeRewriter(
-        (op1, "x"),
-        ("x"),
+def test_patternsub_multi_output_nodes():
+    # Test that PatternNodeRewriter won't attempt to replace multi-output nodes
+    multiple_op_ps = PatternNodeRewriter(
+        (op_multiple_outputs, "x"),
+        "x",
         name="ps",
     )
-    rewriter = in2out(ps)
+
+    single_op_ps = PatternNodeRewriter(
+        (op_y, "x"),
+        "x",
+        name="ps",
+    )
+
+    rewriter = in2out(multiple_op_ps, single_op_ps)

     x = MyVariable("x")
     e1, e2 = op_multiple_outputs(x)
-    o = op1(e1)
+    o1, o2 = op_y(e1), op_y(e2)
+
+    fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False)
+    rewriter.rewrite(fgraph)
+    # This shouldn't rewrite because PatternNodeRewriter has no way of specifying which output(s) are being matched
+    assert_equal_computations(fgraph.outputs, [e2, e1])

-    fgraph = FunctionGraph(inputs=[x], outputs=[o])
+    fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False)
     rewriter.rewrite(fgraph)
-    assert fgraph.outputs[0].owner.op == op1
+    # Having a variable that comes out of a multi-output node should be fine
+    assert_equal_computations(fgraph.outputs, [e2, e1])


 class TestSequentialNodeRewriter:

tests/graph/utils.py

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,9 @@ def make_node(self, *inputs):


 class MyOpMultipleOutputs(MyOp):
+    def __init__(self, name, dmap=None, x=None):
+        super().__init__(name=name, dmap=dmap, x=x, n_outs=2)
+
     def make_node(self, input):
         outputs = [input.type(), input.type()]
         return Apply(self, [input], outputs)

tests/tensor/rewriting/test_math.py

Lines changed: 22 additions & 8 deletions
@@ -50,6 +50,7 @@
     bitwise_and,
     bitwise_or,
     bitwise_xor,
+    cast,
     conj,
     cosh,
     deg2rad,
@@ -124,6 +125,7 @@
     dvector,
     fmatrices,
     fmatrix,
+    fscalar,
     ftensor4,
     fvector,
     imatrices,
@@ -4069,25 +4071,36 @@ def test_exp_over_1_plus_exp(self):

     def test_local_1msigmoid(self):
         m = self.get_mode(excluding=["fusion", "inplace"])
-        x = fmatrix()
+        x = fscalar()
+        xd = dscalar()

         # Test `exp_over_1_plus_exp`
         f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
         # FIXME: PatternNodeRewriter does not copy stack trace
         # (see https://github.com/Theano/Theano/issues/4581)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test `inv_1_plus_exp`
         f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test float constant
-        f = pytensor.function(
-            [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
-        )
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        for out, expected in [
+            (np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
+            (np.array(1.0, "float64") - pt.sigmoid(x), cast(sigmoid(-x), "float64")),
+            (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
+            (np.array(1.0, "float64") - sigmoid(xd), sigmoid(-xd)),
+            (np.sum(1 / np.array([2, 3, 6], "float32")) - sigmoid(x), sigmoid(-x)),
+            (np.sum(1 / np.array([2, 3, 6], "float64")) - sigmoid(xd), sigmoid(-xd)),
+            (np.float32(1 - 9e-6) - sigmoid(x), np.float32(1 - 9e-6) - sigmoid(x)),
+            (np.float64(1 - 1e-9) - sigmoid(xd), np.float64(1 - 1e-9) - sigmoid(xd)),
+        ]:
+            rewritten = rewrite_graph(
+                out, include=["canonicalize", "specialize", "stabilize"]
+            )
+            utt.assert_equal_computations([rewritten], [expected], original=out)

     def test_local_sigm_times_exp(self):
         """
@@ -4235,7 +4248,8 @@ def test_log1msigm_to_softplus(self):
         f(np.random.random((54, 11)).astype(config.floatX))

         # Test close to 1
-        out = log(1.000001 - sigmoid(x))
+        x_dtype = np.dtype(x.dtype).type
+        out = log(np.nextafter(x_dtype(1), x_dtype(2)) - sigmoid(x))
         f = pytensor.function([x], out, mode=self.m)
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 2
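A quick numeric check (standalone, not part of the commit) of why the `1.000001` constant had to become `np.nextafter(1, 2)`: under a 10-ULP tolerance the old constant still reads as "1" in float32 but not in float64, so the rewrite would silently stop firing whenever `x` is float64:

```python
import numpy as np

# Distance of 1.000001 from 1, measured in ULPs of 1 for each dtype.
# np.nextafter(1, 2) is instead exactly 1 ULP above 1 in the variable's
# own dtype, keeping the test constant inside the band either way.
for dt in (np.float32, np.float64):
    ulps_away = (dt(1.000001) - dt(1)) / np.spacing(dt(1))
    print(dt.__name__, round(float(ulps_away)))  # ~8 vs ~4.5e9
```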

tests/unittest_tools.py

Lines changed: 36 additions & 0 deletions
@@ -11,6 +11,7 @@
 from pytensor.compile.debugmode import str_diagnostic
 from pytensor.configdefaults import config
 from pytensor.gradient import verify_grad as orig_verify_grad
+from pytensor.graph.basic import equal_computations
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.math import add as pt_add
@@ -279,6 +280,41 @@ def assert_allclose(expected, value, rtol=None, atol=None):
         raise WrongValue(expected, value, rtol, atol)


+def assert_equal_computations(rewritten, expected, *args, original=None, **kwargs):
+    """
+    Assert that `rewritten` computes the same as `expected`.
+
+    Parameters
+    ----------
+    rewritten
+        The expression after the rewrite pass.
+    expected
+        The reference expression to compare against.
+    *args, **kwargs
+        Extra arguments forwarded to equal_computations.
+    original : optional
+        If given, will be printed in the error message.
+    """
+    __tracebackhide__ = True  # Hide traceback for py.test
+
+    ok = equal_computations(rewritten, expected, *args, **kwargs)
+
+    if not ok:
+        parts = []
+
+        def _dprint(expr):
+            return pytensor.dprint(expr, print_type=True, file="str")
+
+        if original is not None:
+            parts.append(f"\nOriginal:\n{_dprint(original)}")
+        parts.append(f"\nRewritten:\n{_dprint(rewritten)}")
+        parts.append(f"\nExpected:\n{_dprint(expected)}")
+
+        raise AssertionError("equal_computations failed\n" + "".join(parts))

+    return True
+
+
 class AttemptManyTimes:
     """Decorator for unit tests that forces a unit test to be attempted
     multiple times. The test needs to pass a certain number of times for it to
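A usage sketch for the new helper, mirroring one of the cases from `test_local_1msigmoid` above:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from tests.unittest_tools import assert_equal_computations

x = pt.fscalar("x")
out = np.array(1.0, "float32") - pt.sigmoid(x)  # same case as in the test
rewritten = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

# Returns True on success; on mismatch it raises AssertionError with dprint
# dumps of the original, rewritten, and expected graphs.
assert_equal_computations([rewritten], [pt.sigmoid(-x)], original=out)
```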
