Skip to content

Commit 96122d1

Browse files
committed
Fix backwards compatibility in ScalarOp hash
1 parent ee107cb commit 96122d1

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

pytensor/scalar/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,13 +1294,12 @@ def L_op(self, inputs, outputs, output_gradients):
12941294
return self.grad(inputs, output_gradients)
12951295

12961296
def __eq__(self, other):
1297-
test = type(self) is type(other) and getattr(
1297+
return type(self) is type(other) and getattr(
12981298
self, "output_types_preference", None
12991299
) == getattr(other, "output_types_preference", None)
1300-
return test
13011300

13021301
def __hash__(self):
1303-
return hash((type(self), getattr(self, "output_types_preference", 0)))
1302+
return hash((type(self), getattr(self, "output_types_preference", None)))
13041303

13051304
def __str__(self):
13061305
if hasattr(self, "name") and self.name:

tests/scalar/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor.graph.fg import FunctionGraph
99
from pytensor.link.c.basic import DualLinker
1010
from pytensor.scalar.basic import (
11+
EQ,
1112
ComplexError,
1213
Composite,
1314
InRange,
@@ -543,3 +544,18 @@ def test_grad_log10():
543544
b_grad = pytensor.gradient.grad(b, a)
544545
assert b.dtype == "float32"
545546
assert b_grad.dtype == "float32"
547+
548+
549+
def test_scalar_hash_default_output_type_preference():
550+
# Old hash used `getattr(self, "output_type_preference", 0)`
551+
# whereas equality used `getattr(self, "output_type_preference", None)`.
552+
# Since 27d797076668fbf0617654fd9b91f92ddb6737e6,
553+
# output_type_preference is always present (None if not specified),
554+
# which led to C-caching errors when comparing old cached Ops and fresh Ops,
555+
# as they evaluated equal but hashed differently
556+
557+
new_eq = EQ()
558+
old_eq = EQ()
559+
del old_eq.output_types_preference # mimic old Op
560+
assert new_eq == old_eq
561+
assert hash(new_eq) == hash(old_eq)

0 commit comments

Comments
 (0)