Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,27 @@ class Op(MetaObject):
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""

__props__: tuple[str, ...] = ()
"""
A tuple of attribute names that define the properties of this Op instance.

These properties are used for equality comparison, hashing, and string representation.
Subclasses should override this with the names of attributes that affect the
computation performed by the Op.

Examples
========

.. code-block:: python

class MyOp(Op):
__props__ = ("param1", "param2")

def __init__(self, param1, param2):
self.param1 = param1
self.param2 = param2
"""

def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.

Expand Down Expand Up @@ -319,6 +340,40 @@ def __call__(
def __ne__(self, other: Any) -> bool:
return not (self == other)

def _props(self) -> tuple:
"""Return a tuple of properties that define this Op instance.

This method returns a tuple containing the values of all properties
listed in the __props__ attribute, if it exists. These properties
are used for equality comparison and hashing.

Returns
-------
tuple
A tuple of property values in the order they appear in __props__.
Returns an empty tuple if __props__ is not defined.
"""
if hasattr(self, "__props__"):
return tuple(getattr(self, prop) for prop in self.__props__)
return ()

def _props_dict(self) -> dict:
"""Return a dictionary mapping property names to their values.

This method returns a dictionary where keys are property names from
the __props__ attribute and values are the corresponding property values.
This is useful in optimization to swap ops that should have the same props.

Returns
-------
dict
A dictionary mapping property names to values.
Returns an empty dict if __props__ is not defined.
"""
if hasattr(self, "__props__"):
return {prop: getattr(self, prop) for prop in self.__props__}
return {}

# Convenience so that subclass implementers don't have to import utils
# just to self.add_tag_trace
add_tag_trace = staticmethod(add_tag_trace)
Expand Down
20 changes: 0 additions & 20 deletions pytensor/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,6 @@ def __new__(cls, name, bases, dct):
if not all(isinstance(p, str) for p in props):
raise TypeError("elements of __props__ have to be strings")

def _props(self):
"""
Tuple of properties of all attributes
"""
return tuple(getattr(self, a) for a in props)

dct["_props"] = _props

def _props_dict(self):
"""This return a dict of all ``__props__`` key-> value.

This is useful in optimization to swap op that should have the
same props. This help detect error that the new op have at
least all the original props.

"""
return {a: getattr(self, a) for a in props}

dct["_props_dict"] = _props_dict

if "__hash__" not in dct:

def __hash__(self):
Expand Down
8 changes: 8 additions & 0 deletions tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def __eq__(self, other):
def __hash__(self):
return hash((type(self), self.a))

# Override to remove __props__ and make this Op non-etuplizable (atomic)
def __getattribute__(self, name):
if name == "__props__":
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '__props__'"
)
return super().__getattribute__(name)


def test_cons():
x_pt = pt.vector("x")
Expand Down
Loading