From bfe92a3867bffa6542dee9352c3278c2927231f0 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 22 Sep 2025 10:33:48 +0200 Subject: [PATCH] Fix backwards compatibility in ScalarOp hash --- pytensor/scalar/basic.py | 5 ++--- tests/scalar/test_basic.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f28a1122c8..339da84cd1 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1294,13 +1294,12 @@ def L_op(self, inputs, outputs, output_gradients): return self.grad(inputs, output_gradients) def __eq__(self, other): - test = type(self) is type(other) and getattr( + return type(self) is type(other) and getattr( self, "output_types_preference", None ) == getattr(other, "output_types_preference", None) - return test def __hash__(self): - return hash((type(self), getattr(self, "output_types_preference", 0))) + return hash((type(self), getattr(self, "output_types_preference", None))) def __str__(self): if hasattr(self, "name") and self.name: diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 5aab9a95cc..2ab5a68c4c 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -8,6 +8,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.link.c.basic import DualLinker from pytensor.scalar.basic import ( + EQ, ComplexError, Composite, InRange, @@ -543,3 +544,18 @@ def test_grad_log10(): b_grad = pytensor.gradient.grad(b, a) assert b.dtype == "float32" assert b_grad.dtype == "float32" + + +def test_scalar_hash_default_output_type_preference(): + # Old hash used `getattr(self, "output_type_preference", 0)` + # whereas equality used `getattr(self, "output_type_preference", None)`. + # Since 27d797076668fbf0617654fd9b91f92ddb6737e6, + # output_type_preference is always present (None if not specified), + # which led to C-caching errors when comparing old cached Ops and fresh Ops, + # as they evaluated equal but hashed differently + + new_eq = EQ() + old_eq = EQ() + del old_eq.output_types_preference # mimic old Op + assert new_eq == old_eq + assert hash(new_eq) == hash(old_eq)