|
7 | 7 | from __future__ import annotations
|
8 | 8 |
|
9 | 9 | from dataclasses import dataclass, field
|
10 |
| -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypeVar, Union, runtime_checkable |
| 10 | +from functools import cached_property |
| 11 | +from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable |
| 12 | +from weakref import WeakValueDictionary |
11 | 13 |
|
12 | 14 | from typing_extensions import Self, assert_never
|
13 | 15 |
|
@@ -569,29 +571,46 @@ class CallDecl:
|
569 | 571 | # Used for pretty printing classmethod calls with type parameters
|
570 | 572 | bound_tp_params: tuple[JustTypeRef, ...] | None = None
|
571 | 573 |
|
| 574 | + # pool objects for faster __eq__ |
| 575 | + _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({}) |
| 576 | + |
| 577 | + def __new__(cls, *args: object, **kwargs: object) -> Self: |
| 578 | + """ |
| 579 | + Pool CallDecls so that they can be compared by identity more quickly. |
| 580 | +
|
| 581 | + Neccessary bc we search for common parents when serializing CallDecl trees to egglog to |
| 582 | + only serialize each sub-tree once. |
| 583 | + """ |
| 584 | + # normalize the args/kwargs to a tuple so that they can be compared |
| 585 | + callable = args[0] if args else kwargs["callable"] |
| 586 | + args_ = args[1] if len(args) > 1 else kwargs.get("args", ()) |
| 587 | + bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params") |
| 588 | + |
| 589 | + normalized_args = (callable, args_, bound_tp_params) |
| 590 | + try: |
| 591 | + return cast(Self, cls._args_to_value[normalized_args]) |
| 592 | + except KeyError: |
| 593 | + res = super().__new__(cls) |
| 594 | + cls._args_to_value[normalized_args] = res |
| 595 | + return res |
| 596 | + |
572 | 597 | def __post_init__(self) -> None:
|
573 | 598 | if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
|
574 | 599 | msg = "Cannot bind type parameters to a non-class method callable."
|
575 | 600 | raise ValueError(msg)
|
576 | 601 |
|
577 |
| - # def __hash__(self) -> int: |
578 |
| - # return self._cached_hash |
579 |
| - |
580 |
| - # @cached_property |
581 |
| - # def _cached_hash(self) -> int: |
582 |
| - # return hash((self.callable, self.args, self.bound_tp_params)) |
583 |
| - |
584 |
| - # def __eq__(self, other: object) -> bool: |
585 |
| - # # Override eq to use cached hash for perf |
586 |
| - # if not isinstance(other, CallDecl): |
587 |
| - # return False |
588 |
| - # if hash(self) != hash(other): |
589 |
| - # return False |
590 |
| - # return ( |
591 |
| - # self.callable == other.callable |
592 |
| - # and self.args == other.args |
593 |
| - # and self.bound_tp_params == other.bound_tp_params |
594 |
| - # ) |
| 602 | + def __hash__(self) -> int: |
| 603 | + return self._cached_hash |
| 604 | + |
| 605 | + @cached_property |
| 606 | + def _cached_hash(self) -> int: |
| 607 | + return hash((self.callable, self.args, self.bound_tp_params)) |
| 608 | + |
| 609 | + def __eq__(self, other: object) -> bool: |
| 610 | + return self is other |
| 611 | + |
| 612 | + def __ne__(self, other: object) -> bool: |
| 613 | + return self is not other |
595 | 614 |
|
596 | 615 |
|
597 | 616 | @dataclass(frozen=True)
|
|
0 commit comments