Skip to content

Commit 2691387

Browse files
Add ability to override hash
1 parent 1dcb6ca commit 2691387

File tree

4 files changed

+23
-1
lines changed

4 files changed

+23
-1
lines changed

python/egglog/egraph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@
110110
"__weakref__",
111111
"__orig_bases__",
112112
"__annotations__",
113-
"__hash__",
114113
"__qualname__",
115114
"__firstlineno__",
116115
"__static_attributes__",

python/egglog/exp/array_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,16 @@ def __le__(self, other: IntLike) -> Boolean: ...
154154
def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
155155
...
156156

157+
# add a hash so that this test can pass
158+
# https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
159+
@method(preserve=True)
160+
def __hash__(self) -> int:
161+
egraph = _get_current_egraph()
162+
egraph.register(self)
163+
egraph.run(array_api_schedule)
164+
simplified = egraph.extract(self)
165+
return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)
166+
157167
# TODO: Fix this?
158168
# Make != always return a Bool, so that numpy.unique works on a tuple of ints
159169
# In _unique1d

python/egglog/runtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,8 @@ def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
567567
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
568568

569569
def __hash__(self) -> int:
570+
if (method := _get_expr_method(self, "__hash__")) is not None:
571+
return cast("int", cast("Any", method()))
570572
return hash(self.__egg_typed_expr__)
571573

572574
# Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a

python/tests/test_high_level.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,17 @@ def test_type_param_sub():
947947
assert str(V[Unit]) == str(Vec[Unit] | int) # type: ignore[misc]
948948

949949

950+
def test_override_hash(self):
951+
class A(Expr):
952+
def __init__(self) -> None: ...
953+
954+
@method(preserve=True)
955+
def __hash__(self) -> int:
956+
return 42
957+
958+
assert hash(A()) == 42
959+
960+
950961
EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))
951962

952963

0 commit comments

Comments
 (0)