Commit c54a34a
Author: Luca Citi

Implemented allow_cast in PatternNodeRewriter

This allows rewrites that would otherwise fail when the new and old dtypes differ. Example: `np.array(1., "float64") - sigmoid(x)` cannot be rewritten as `sigmoid(-x)` (where `x` is an fmatrix) because the type would change. This commit allows an automatic cast to be added so the expression is rewritten as `cast(sigmoid(-x), "float64")`. Relevant tests added.

1 parent: 0bb15f9
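
To make the motivating example concrete, the snippet below checks the dtype mismatch described above. It is a minimal sketch using the public pytensor.tensor API; the variable names are illustrative and the comments state the dtypes expected under standard upcasting rules.

import numpy as np
import pytensor.tensor as pt

x = pt.fmatrix("x")                                 # float32 input
before = np.array(1.0, "float64") - pt.sigmoid(x)   # upcast to float64 by the constant
after = pt.sigmoid(-x)                              # stays float32

# The rewrite used to be rejected because these dtypes differ; with
# allow_cast the replacement becomes pt.cast(after, "float64") instead.
print(before.dtype, after.dtype)                    # float64 float32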

File tree: 2 files changed (+38, -12 lines)


pytensor/graph/rewriting/basic.py

Lines changed: 18 additions & 6 deletions
@@ -1550,6 +1550,7 @@ def __init__(
         tracks=(),
         get_nodes=None,
         values_eq_approx=None,
+        allow_cast=True,
     ):
         """
 
@@ -1572,6 +1573,10 @@ def __init__(
             If you provide `tracks`, you must provide this parameter. It must be a
             function that takes the tracked node and returns a list of nodes on
             which we will try this rewrite.
+        values_eq_approx
+            TODO
+        allow_cast
+            Automatically cast the output of the rewrite whenever new and old types differ
 
         Notes
         -----
@@ -1586,6 +1591,7 @@ def __init__(
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
+        self.allow_cast = allow_cast
         if isinstance(in_pattern, list | tuple):
             self.op = self.in_pattern[0]
         elif isinstance(in_pattern, dict):
@@ -1653,14 +1659,20 @@ def transform(self, fgraph, node, get_nodes=True):
             return False
 
         if ret.owner:
-            if not (
-                len(node.outputs) == len(ret.owner.outputs)
-                and all(
+            if len(node.outputs) != len(ret.owner.outputs):
+                return
+            if len(node.outputs) > 1:
+                if not all(
                     o.type.is_super(new_o.type)
                     for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
-                )
-            ):
-                return False
+                ):
+                    return False
+            else:
+                out_dtype = node.outputs[0].type.dtype
+                if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
+                    ret = pytensor.tensor.basic.cast(ret, out_dtype)
+                if not node.outputs[0].type.is_super(ret.owner.outputs[0].type):
+                    return False
         else:
             # ret is just an input variable
             assert len(node.outputs) == 1
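
For orientation, a pattern rewrite that takes advantage of the new flag could be declared roughly as follows. This is a hedged sketch: the constructor and `allow_cast` argument come from the diff above, while the specific `1.0 - sigmoid(x) -> sigmoid(-x)` pattern and the tuple syntax mirror existing pytensor pattern rewrites and are shown for illustration only.

from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor import neg, sigmoid, sub

# Rewrite `1.0 - sigmoid(x)` into `sigmoid(-x)`. With allow_cast=True,
# transform() inserts a cast when the replacement's dtype differs from the
# original output's dtype (e.g. a float32 sigmoid under a float64 constant)
# instead of rejecting the rewrite.
one_minus_sigmoid_to_sigmoid_neg = PatternNodeRewriter(
    (sub, 1.0, (sigmoid, "x")),
    (sigmoid, (neg, "x")),
    allow_cast=True,
)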

tests/tensor/rewriting/test_math.py

Lines changed: 20 additions & 6 deletions
@@ -49,6 +49,7 @@
     bitwise_and,
     bitwise_or,
     bitwise_xor,
+    cast,
     conj,
     cosh,
     deg2rad,
@@ -4115,24 +4116,37 @@ def test_exp_over_1_plus_exp(self):
     def test_local_1msigmoid(self):
         m = self.get_mode(excluding=["fusion", "inplace"])
         x = fmatrix()
+        xd = dmatrix()
 
         # Test `exp_over_1_plus_exp`
         f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
         # FIXME: PatternNodeRewriter does not copy stack trace
         # (see https://github.com/Theano/Theano/issues/4581)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
 
         # Test `inv_1_plus_exp`
         f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])
 
         # Test float constant
-        f = pytensor.function(
-            [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
-        )
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        for out, expected in [
+            (np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
+            (np.array(1.0, "float64") - pt.sigmoid(x), cast(sigmoid(-x), "float64")),
+            (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
+            (np.array([[1.0]], "float64") - sigmoid(xd), sigmoid(-xd)),
+        ]:
+            f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
+            f_outs = f.maker.fgraph.outputs
+            assert equal_computations(
+                f_outs, [expected]
+            ), "Expression:\n{}rewritten as:\n{}expected:\n{}".format(
+                *(
+                    pytensor.dprint(expr, print_type=True, file="str")
+                    for expr in (out, f_outs, expected)
+                )
+            )
 
     def test_local_sigm_times_exp(self):
         """
