Merged
24 changes: 18 additions & 6 deletions pytensor/graph/rewriting/basic.py
@@ -1550,6 +1550,7 @@ def __init__(
         tracks=(),
         get_nodes=None,
         values_eq_approx=None,
+        allow_cast=True,
     ):
         """

@@ -1572,6 +1573,10 @@
             If you provide `tracks`, you must provide this parameter. It must be a
             function that takes the tracked node and returns a list of nodes on
             which we will try this rewrite.
+        values_eq_approx
+            TODO
+        allow_cast
+            Automatically cast the output of the rewrite whenever the new and old types differ.

         Notes
         -----
@@ -1586,6 +1591,7 @@ def __init__(
         self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
+        self.allow_cast = allow_cast
         if isinstance(in_pattern, list | tuple):
             self.op = self.in_pattern[0]
         elif isinstance(in_pattern, dict):
@@ -1653,14 +1659,20 @@ def transform(self, fgraph, node, get_nodes=True):
                 return False

         if ret.owner:
-            if not (
-                len(node.outputs) == len(ret.owner.outputs)
-                and all(
-                    o.type.is_super(new_o.type)
-                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
-                )
-            ):
-                return False
+            if len(node.outputs) != len(ret.owner.outputs):
+                return
+            if len(node.outputs) > 1:
+                if not all(
+                    o.type.is_super(new_o.type)
+                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
+                ):
+                    return False
+            else:
+                out_dtype = node.outputs[0].type.dtype
+                if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
+                    ret = pytensor.tensor.basic.cast(ret, out_dtype)

Member:

Not all types have a dtype; we should check that it's a `TensorType` before even trying to access `dtype` and doing anything with it.

The whole logic with the `if ret.owner` is weird, though: why do we care about the types of outputs we're not replacing? It's actually dangerous to try to replace only one of them without the user's consent. Since this is WIP, I would change it to `if len(node.outputs) != 1: return False` before we try to unify.

Then here we just have to worry about the final else branch below, which I would perhaps write like this:

[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
  if not (
    self.allow_cast 
    and isinstance(old_out.type, TensorType) 
    and isinstance(ret.type, TensorType)
  ):
    return False

  # Try to cast
  ret = ret.astype(old_out.type.dtype)
  if not old_out.type.is_super(ret.type):
    return False

Contributor Author:

I am happy to replace it as you suggest, but I am not sure how to fit it in with the rest. This is the current code:

        if ret.owner:
            if not (
                len(node.outputs) == len(ret.owner.outputs)
                and all(
                    o.type.is_super(new_o.type)
                    for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
                )
            ):
                return False
        else:
            # ret is just an input variable
            assert len(node.outputs) == 1
            if not node.outputs[0].type.is_super(ret.type):
                return False

Member (@ricardoV94, Jul 8, 2025):

You only need what I wrote above; as a template, something like this:

def transform(...):

...

if node.op != self.op:
    return False

if len(node.outputs) != 1:
  # PatternNodeRewriter doesn't support replacing multi-output nodes  
  return False

...

if not self.allow_multiple_clients:
  ...


# New logic

[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
  # Type doesn't match
  if not (
    self.allow_cast 
    and isinstance(old_out.type, TensorType) 
    and isinstance(ret.type, TensorType)
  ):
    return False

  # Try to cast tensors
  ret = ret.astype(old_out.type.dtype)
  if not old_out.type.is_super(ret.type):
    # Still doesn't match
    return False

return [ret]

Contributor Author:

Are you sure PatternNodeRewriter is supposed to only work with single-output nodes? I get the following error:

    def test_patternsub_different_output_lengths():
        # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
        ps = PatternNodeRewriter(
            (op1, "x"),
            ("x"),
            name="ps",
        )
        rewriter = in2out(ps)
    
        x = MyVariable("x")
        e1, e2 = op_multiple_outputs(x)
        o = op1(e1)
    
        fgraph = FunctionGraph(inputs=[x], outputs=[o])
        rewriter.rewrite(fgraph)
>       assert fgraph.outputs[0].owner.op == op1
E       assert OpMultipleOutputs == op1
E        +  where OpMultipleOutputs = OpMultipleOutputs(x).op
E        +    where OpMultipleOutputs(x) = OpMultipleOutputs.0.owner

Member (@ricardoV94, Jul 9, 2025):

I don't think that test makes sense. It's like saying you don't want to replace `log(exp(x))` if `x` comes from a multi-output node. We usually don't care about the provenance of a root variable in a rewrite; nothing in that rewrite cares about `op_multiple_outputs`.
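
For illustration of this point, here is a minimal sketch (an arbitrary log/exp pattern, not code from this PR):

import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import PatternNodeRewriter

# Rewrite log(exp(x)) -> x. The pattern constrains only the node being
# rewritten (log) and its input (exp); whether "x" is a graph input, a
# constant, or one output of a multi-output node never enters the match.
log_exp = PatternNodeRewriter(
    (pt.log, (pt.exp, "x")),
    "x",
    name="log_exp",
)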

Contributor Author:

Thanks. It sort of makes sense to me, but I know too little of the PyTensor internals to fully understand.
Can you propose a quick way to modify or replace the test with one where it refuses to replace `OpMultipleOutputs`?

Member:

If you push your changes (if you haven't already), I can push the new test on top of it.
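
For illustration, such a test might look like the sketch below. This is hypothetical (the commit actually pushed may differ); it reuses the helpers from the test quoted above (op_multiple_outputs, MyVariable, in2out, FunctionGraph) and assumes the `len(node.outputs) != 1` guard from the proposed template:

def test_patternsub_refuses_multi_output_nodes():
    # Hypothetical: track the multi-output op itself and check that the
    # rewriter leaves the graph alone, since it cannot replace just one
    # output of the node.
    ps = PatternNodeRewriter(
        (op_multiple_outputs, "x"),
        "x",
        name="ps",
    )
    rewriter = in2out(ps)

    x = MyVariable("x")
    e1, e2 = op_multiple_outputs(x)

    fgraph = FunctionGraph(inputs=[x], outputs=[e1, e2])
    rewriter.rewrite(fgraph)
    # Both outputs still come from the multi-output node.
    assert fgraph.outputs[0].owner.op == op_multiple_outputs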

Contributor Author:

I have pushed all my changes.

Member:

I pushed a commit that changes the behavior of the test; have a look and let me know if anything else is missing.

Contributor Author:

I think it's ready to go. Thank you for your help.

+                if not node.outputs[0].type.is_super(ret.owner.outputs[0].type):
+                    return False
         else:
             # ret is just an input variable
             assert len(node.outputs) == 1
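
To make the effect of the new branch concrete, here is a minimal sketch of the user-visible behavior; it mirrors the float64-constant case added to test_local_1msigmoid in the test diff below, and only the allow_cast flag itself is new API:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.fmatrix("x")  # float32 input
# A float64 constant minus a float32 sigmoid upcasts the output to float64.
out = np.array(1.0, dtype="float64") - pt.sigmoid(x)

f = pytensor.function([x], out)
# With allow_cast=True, the `1 - sigmoid(x) -> sigmoid(-x)` rewrite can still
# apply even though sigmoid(-x) is float32: the rewriter wraps the replacement
# in a cast back to float64 instead of rejecting it.
pytensor.dprint(f)
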
26 changes: 20 additions & 6 deletions tests/tensor/rewriting/test_math.py
@@ -49,6 +49,7 @@
     bitwise_and,
     bitwise_or,
     bitwise_xor,
+    cast,
     conj,
     cosh,
     deg2rad,
@@ -4115,24 +4116,37 @@ def test_exp_over_1_plus_exp(self):
     def test_local_1msigmoid(self):
         m = self.get_mode(excluding=["fusion", "inplace"])
         x = fmatrix()
+        xd = dmatrix()

         # Test `exp_over_1_plus_exp`
         f = pytensor.function([x], 1 - exp(x) / (1 + exp(x)), mode=m)
         # FIXME: PatternNodeRewriter does not copy stack trace
         # (see https://github.com/Theano/Theano/issues/4581)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test `inv_1_plus_exp`
         f = pytensor.function([x], 1 - pt.fill(x, 1.0) / (1 + exp(-x)), mode=m)
         # assert check_stack_trace(f, ops_to_check=[neg, sigmoid])
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        assert equal_computations(f.maker.fgraph.outputs, [sigmoid(-x)])

         # Test float constant
-        f = pytensor.function(
-            [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m
-        )
-        assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid]
+        for out, expected in [
+            (np.array(1.0, "float32") - sigmoid(x), sigmoid(-x)),
+            (np.array(1.0, "float64") - pt.sigmoid(x), cast(sigmoid(-x), "float64")),
+            (np.array(1.0, "float32") - sigmoid(xd), sigmoid(-xd)),
+            (np.array([[1.0]], "float64") - sigmoid(xd), sigmoid(-xd)),
+        ]:
+            f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
+            f_outs = f.maker.fgraph.outputs
+            assert equal_computations(
+                f_outs, [expected]
+            ), "Expression:\n{}rewritten as:\n{}expected:\n{}".format(
+                *(
+                    pytensor.dprint(expr, print_type=True, file="str")
+                    for expr in (out, f_outs, expected)
+                )
+            )

     def test_local_sigm_times_exp(self):
         """