|
13 | 13 | from typing_extensions import Self, assert_never |
14 | 14 |
|
15 | 15 | if TYPE_CHECKING: |
16 | | - from collections.abc import Callable, Iterable |
| 16 | + from collections.abc import Callable, Iterable, Mapping |
17 | 17 |
|
18 | 18 |
|
19 | 19 | __all__ = [ |
| 20 | + "replace_typed_expr", |
20 | 21 | "Declarations", |
21 | 22 | "DeclerationsLike", |
22 | 23 | "DelayedDeclerations", |
|
29 | 30 | "MethodRef", |
30 | 31 | "ClassMethodRef", |
31 | 32 | "FunctionRef", |
| 33 | + "UnnamedFunctionRef", |
32 | 34 | "ConstantRef", |
33 | 35 | "ClassVariableRef", |
34 | 36 | "PropertyRef", |
@@ -83,17 +85,14 @@ class DelayedDeclerations: |
83 | 85 |
|
84 | 86 | @property |
85 | 87 | def __egg_decls__(self) -> Declarations: |
| 88 | + thunk = self.__egg_decls_thunk__ |
86 | 89 | try: |
87 | | - return self.__egg_decls_thunk__() |
| 90 | + return thunk() |
88 | 91 | # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__` |
89 | 92 | # instead raise explicitly |
90 | 93 | except AttributeError as err: |
91 | 94 | msg = f"Cannot resolve declerations for {self}" |
92 | 95 | raise RuntimeError(msg) from err |
93 | | - # Might as well catch others too so we have more context when they raise |
94 | | - except Exception as err: # noqa: BLE001 |
95 | | - msg = f"Cannot resolve declerations for {self}" |
96 | | - raise RuntimeError(msg) from err |
97 | 96 |
|
98 | 97 |
|
99 | 98 | @runtime_checkable |
@@ -121,7 +120,8 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D |
121 | 120 |
|
122 | 121 | @dataclass |
123 | 122 | class Declarations: |
124 | | - _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict) |
| 123 | + # TODO: Replace with set of unnamed function decls |
| 124 | + _functions: dict[str | UnnamedFunctionRef, FunctionDecl | RelationDecl] = field(default_factory=dict) |
125 | 125 | _constants: dict[str, ConstantDecl] = field(default_factory=dict) |
126 | 126 | _classes: dict[str, ClassDecl] = field(default_factory=dict) |
127 | 127 | _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])}) |
@@ -323,9 +323,26 @@ def __str__(self) -> str: |
323 | 323 | ## |
324 | 324 |
|
325 | 325 |
|
| 326 | +@dataclass(frozen=True) |
| 327 | +class UnnamedFunctionRef: |
| 328 | + """ |
| 329 | + A reference to a function that doesn't have a name, but does have a body. |
| 330 | + """ |
| 331 | + |
| 332 | + arg_types: tuple[JustTypeRef, ...] |
| 333 | + arg_names: tuple[str, ...] |
| 334 | + res: TypedExprDecl |
| 335 | + |
| 336 | + @property |
| 337 | + def args(self) -> tuple[TypedExprDecl, ...]: |
| 338 | + return tuple( |
| 339 | + TypedExprDecl(tp, VarDecl(name, False)) for tp, name in zip(self.arg_types, self.arg_names, strict=True) |
| 340 | + ) |
| 341 | + |
| 342 | + |
326 | 343 | @dataclass(frozen=True) |
327 | 344 | class FunctionRef: |
328 | | - name: str |
| 345 | + name: str | UnnamedFunctionRef |
329 | 346 |
|
330 | 347 |
|
331 | 348 | @dataclass(frozen=True) |
@@ -460,6 +477,8 @@ def to_function_decl(self) -> FunctionDecl: |
460 | 477 | @dataclass(frozen=True) |
461 | 478 | class VarDecl: |
462 | 479 | name: str |
| 480 | + # Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix |
| 481 | + is_let: bool |
463 | 482 |
|
464 | 483 |
|
465 | 484 | @dataclass(frozen=True) |
@@ -566,6 +585,38 @@ def descendants(self) -> list[TypedExprDecl]: |
566 | 585 | return l |
567 | 586 |
|
568 | 587 |
|
| 588 | +def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl: |
| 589 | + """ |
| 590 | + Replace all the typed expressions in the given typed expression with the replacements. |
| 591 | + """ |
| 592 | + # keep track of the traversed expressions for memoization |
| 593 | + traversed: dict[TypedExprDecl, TypedExprDecl] = {} |
| 594 | + |
| 595 | + def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl: |
| 596 | + if typed_expr in traversed: |
| 597 | + return traversed[typed_expr] |
| 598 | + if typed_expr in replacements: |
| 599 | + res = replacements[typed_expr] |
| 600 | + else: |
| 601 | + match typed_expr.expr: |
| 602 | + case ( |
| 603 | + CallDecl(callable, args, bound_tp_params) |
| 604 | + | PartialCallDecl(CallDecl(callable, args, bound_tp_params)) |
| 605 | + ): |
| 606 | + new_args = tuple(_inner(a) for a in args) |
| 607 | + call_decl = CallDecl(callable, new_args, bound_tp_params) |
| 608 | + res = TypedExprDecl( |
| 609 | + typed_expr.tp, |
| 610 | + call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl), |
| 611 | + ) |
| 612 | + case _: |
| 613 | + res = typed_expr |
| 614 | + traversed[typed_expr] = res |
| 615 | + return res |
| 616 | + |
| 617 | + return _inner(typed_expr) |
| 618 | + |
| 619 | + |
569 | 620 | ## |
570 | 621 | # Schedules |
571 | 622 | ## |
|
0 commit comments