pytensor/graph/rewriting/basic.py (26 changes: 20 additions & 6 deletions)

@@ -1550,6 +1550,7 @@ def __init__(
         tracks=(),
         get_nodes=None,
         values_eq_approx=None,
+        allow_cast=True,
     ):
         """

@@ -1572,6 +1573,10 @@
             If you provide `tracks`, you must provide this parameter. It must be a
             function that takes the tracked node and returns a list of nodes on
             which we will try this rewrite.
+        values_eq_approx
+            TODO
+        allow_cast
+            Automatically cast the output of the rewrite whenever the new and old types differ.

         Notes
         -----
@@ -1586,6 +1591,7 @@
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
+        self.allow_cast = allow_cast
         if isinstance(in_pattern, list | tuple):
             self.op = self.in_pattern[0]
         elif isinstance(in_pattern, dict):
@@ -1653,14 +1659,22 @@ def transform(self, fgraph, node, get_nodes=True):
             return False

         if ret.owner:
-            if not (
-                len(node.outputs) == len(ret.owner.outputs)
-                and all(
+            if len(node.outputs) != len(ret.owner.outputs):
+                return False
+            if len(node.outputs) > 1:
+                if not all(
                     o.type.is_super(new_o.type)
                     for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
-                )
-            ):
-                return False
+                ):
+                    return False
+            else:
+                if self.allow_cast:
+                    out_dtype = getattr(node.outputs[0].type, "dtype", None)
+                    ret_dtype = getattr(ret.owner.outputs[0].type, "dtype", None)
+                    if ret_dtype != out_dtype:
+                        ret = pytensor.tensor.basic.cast(ret, out_dtype)
+                if not node.outputs[0].type.is_super(ret.owner.outputs[0].type):
+                    return False
         else:
             # ret is just an input variable
             assert len(node.outputs) == 1
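For context, a minimal sketch of what `allow_cast` buys (not code from this PR; the rewriter name is invented, while the pattern syntax mirrors the `local_1msigmoid` rewrite exercised by the updated test below). The rewrite turns `1.0 - sigmoid(x)` into `sigmoid(-x)`. With an explicit float64 constant and a float32 `x`, the matched output is float64 while the replacement is float32, so previously the `is_super` check rejected the substitution; with `allow_cast=True`, the new single-output branch wraps the replacement in a cast back to float64 instead.

# A sketch, assuming the usual PatternNodeRewriter pattern syntax and the
# standard `name` keyword; "one_minus_sigmoid_sketch" is a made-up name.
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import PatternNodeRewriter

one_minus_sigmoid = PatternNodeRewriter(
    (pt.sub, 1.0, (pt.sigmoid, "x")),  # match:   1.0 - sigmoid(x)
    (pt.sigmoid, (pt.neg, "x")),       # rewrite: sigmoid(-x)
    allow_cast=True,  # cast the replacement when its dtype differs from the original
    name="one_minus_sigmoid_sketch",
)

Note that the cast is only attempted in the single-output branch; multi-output replacements still require exact type compatibility via `is_super`.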
tests/tensor/rewriting/test_math.py (33 changes: 26 additions & 7 deletions)

@@ -49,6 +49,7 @@
     bitwise_and,
     bitwise_or,
     bitwise_xor,
+    cast,
     conj,
     cosh,
     deg2rad,

@@ -123,6 +124,7 @@
     dvector,
     fmatrices,
     fmatrix,
+    fscalar,
     ftensor4,
     fvector,
     imatrices,

@@ -4114,25 +4116,42 @@ def test_exp_over_1_plus_exp(self):

     def test_local_1msigmoid(self):
         m = self.get_mode(excluding=["fusion", "inplace"])
-        x = fmatrix()
+        x = fscalar()
+        xd = dscalar()

         # Test `exp_over_1_plus_exp`
         f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
         # FIXME: PatternNodeRewriter does not copy stack trace
         # (see https://github.com/Theano/Theano/issues/4581)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test `inv_1_plus_exp`
         f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test float constant
-        f = pytensor.function(
-            [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
-        )
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        for out, expected in [
+            (np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
+            (np.array(1.0, "float64") - pt.sigmoid(x), cast(sigmoid(-x), "float64")),
+            (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
+            (np.array(1.0, "float64") - sigmoid(xd), sigmoid(-xd)),
+            (np.sum(1 / np.array([2, 3, 6], "float32")) - sigmoid(x), sigmoid(-x)),
+            (np.sum(1 / np.array([2, 3, 6], "float64")) - sigmoid(xd), sigmoid(-xd)),
+            (np.float32(1 - 9e-6) - sigmoid(x), np.float32(1 - 9e-6) - sigmoid(x)),
+            (np.float64(1 - 1e-9) - sigmoid(xd), np.float64(1 - 1e-9) - sigmoid(xd)),
+        ]:
+            f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
+            f_outs = f.maker.fgraph.outputs
+            assert equal_computations(
+                f_outs, [expected]
+            ), "Expression:\n{}rewritten as:\n{}expected:\n{}".format(
+                *(
+                    pytensor.dprint(expr, print_type=True, file="str")
+                    for expr in (out, f_outs, expected)
+                )
+            )

     def test_local_sigm_times_exp(self):
         """