Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,13 +1294,12 @@ def L_op(self, inputs, outputs, output_gradients):
return self.grad(inputs, output_gradients)

def __eq__(self, other):
test = type(self) is type(other) and getattr(
return type(self) is type(other) and getattr(
self, "output_types_preference", None
) == getattr(other, "output_types_preference", None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the None default instead of 0 that was used in the hash. That was the issue

return test

def __hash__(self):
return hash((type(self), getattr(self, "output_types_preference", 0)))
return hash((type(self), getattr(self, "output_types_preference", None)))

def __str__(self):
if hasattr(self, "name") and self.name:
Expand Down
16 changes: 16 additions & 0 deletions tests/scalar/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker
from pytensor.scalar.basic import (
EQ,
ComplexError,
Composite,
InRange,
Expand Down Expand Up @@ -543,3 +544,18 @@ def test_grad_log10():
b_grad = pytensor.gradient.grad(b, a)
assert b.dtype == "float32"
assert b_grad.dtype == "float32"


def test_scalar_hash_default_output_type_preference():
# Old hash used `getattr(self, "output_type_preference", 0)`
# whereas equality used `getattr(self, "output_type_preference", None)`.
# Since 27d797076668fbf0617654fd9b91f92ddb6737e6,
# output_type_preference is always present (None if not specified),
# which led to C-caching errors when comparing old cached Ops and fresh Ops,
# as they evaluated equal but hashed differently

new_eq = EQ()
old_eq = EQ()
del old_eq.output_types_preference # mimic old Op
assert new_eq == old_eq
assert hash(new_eq) == hash(old_eq)