Skip to content

Commit eb118bb

Browse files
Switch to pooled objects
1 parent 8f948d1 commit eb118bb

File tree

1 file changed

+38
-19
lines changed

1 file changed

+38
-19
lines changed

python/egglog/declarations.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from __future__ import annotations
88

99
from dataclasses import dataclass, field
10-
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypeVar, Union, runtime_checkable
10+
from functools import cached_property
11+
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
12+
from weakref import WeakValueDictionary
1113

1214
from typing_extensions import Self, assert_never
1315

@@ -569,29 +571,46 @@ class CallDecl:
569571
# Used for pretty printing classmethod calls with type parameters
570572
bound_tp_params: tuple[JustTypeRef, ...] | None = None
571573

574+
# pool objects for faster __eq__
575+
_args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})
576+
577+
def __new__(cls, *args: object, **kwargs: object) -> Self:
578+
"""
579+
Pool CallDecls so that they can be compared by identity more quickly.
580+
581+
Neccessary bc we search for common parents when serializing CallDecl trees to egglog to
582+
only serialize each sub-tree once.
583+
"""
584+
# normalize the args/kwargs to a tuple so that they can be compared
585+
callable = args[0] if args else kwargs["callable"]
586+
args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
587+
bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")
588+
589+
normalized_args = (callable, args_, bound_tp_params)
590+
try:
591+
return cast(Self, cls._args_to_value[normalized_args])
592+
except KeyError:
593+
res = super().__new__(cls)
594+
cls._args_to_value[normalized_args] = res
595+
return res
596+
572597
def __post_init__(self) -> None:
573598
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
574599
msg = "Cannot bind type parameters to a non-class method callable."
575600
raise ValueError(msg)
576601

577-
# def __hash__(self) -> int:
578-
# return self._cached_hash
579-
580-
# @cached_property
581-
# def _cached_hash(self) -> int:
582-
# return hash((self.callable, self.args, self.bound_tp_params))
583-
584-
# def __eq__(self, other: object) -> bool:
585-
# # Override eq to use cached hash for perf
586-
# if not isinstance(other, CallDecl):
587-
# return False
588-
# if hash(self) != hash(other):
589-
# return False
590-
# return (
591-
# self.callable == other.callable
592-
# and self.args == other.args
593-
# and self.bound_tp_params == other.bound_tp_params
594-
# )
602+
def __hash__(self) -> int:
603+
return self._cached_hash
604+
605+
@cached_property
606+
def _cached_hash(self) -> int:
607+
return hash((self.callable, self.args, self.bound_tp_params))
608+
609+
def __eq__(self, other: object) -> bool:
610+
return self is other
611+
612+
def __ne__(self, other: object) -> bool:
613+
return self is not other
595614

596615

597616
@dataclass(frozen=True)

0 commit comments

Comments
 (0)