|
16 | 16 | ClassVar, |
17 | 17 | Generic, |
18 | 18 | Literal, |
19 | | - Never, |
20 | 19 | TypeAlias, |
21 | 20 | TypedDict, |
22 | 21 | TypeVar, |
|
26 | 25 | ) |
27 | 26 |
|
28 | 27 | import graphviz |
29 | | -from typing_extensions import ParamSpec, Self, Unpack, assert_never |
| 28 | +from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never |
30 | 29 |
|
31 | 30 | from . import bindings |
32 | 31 | from .conversion import * |
|
36 | 35 | from .pretty import pretty_decl |
37 | 36 | from .runtime import * |
38 | 37 | from .thunk import * |
| 38 | +from .version_compat import * |
39 | 39 |
|
40 | 40 | if TYPE_CHECKING: |
41 | 41 | from .builtins import String, Unit |
@@ -169,8 +169,9 @@ def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, ad |
169 | 169 | except bindings.EggSmolError as err: |
170 | 170 | if display: |
171 | 171 | egraph.display() |
172 | | - err.add_note(f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})") |
173 | | - raise |
| 172 | + raise add_note( |
| 173 | + f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})", err |
| 174 | + ) from None |
174 | 175 | return egraph |
175 | 176 |
|
176 | 177 |
|
@@ -492,8 +493,7 @@ def _generate_class_decls( # noqa: C901,PLR0912 |
492 | 493 | reverse_args=reverse_args, |
493 | 494 | ) |
494 | 495 | except Exception as e: |
495 | | - e.add_note(f"Error processing {cls_name}.{method_name}") |
496 | | - raise |
| 496 | + raise add_note(f"Error processing {cls_name}.{method_name}", e) from None |
497 | 497 |
|
498 | 498 | if not builtin and not isinstance(ref, InitRef) and not mutates: |
499 | 499 | add_default_funcs.append(add_rewrite) |
@@ -569,16 +569,11 @@ def _fn_decl( |
569 | 569 | if not isinstance(fn, FunctionType): |
570 | 570 | raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}") |
571 | 571 |
|
572 | | - hint_globals = fn.__globals__.copy() |
573 | | - # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block |
574 | | - # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/ |
575 | | - if "Callable" not in hint_globals: |
576 | | - hint_globals["Callable"] = Callable |
577 | 572 | # Instead of passing both globals and locals, just pass the globals. Otherwise, for some reason forward references |
578 | 573 | # won't be resolved correctly |
579 | 574 | # We need this to be false so it returns "__forward_value__" https://github.com/python/cpython/blob/440ed18e08887b958ad50db1b823e692a747b671/Lib/typing.py#L919 |
580 | 575 | # https://github.com/egraphs-good/egglog-python/issues/210 |
581 | | - hint_globals.update(hint_locals) |
| 576 | + hint_globals = {**fn.__globals__, **hint_locals} |
582 | 577 | hints = get_type_hints(fn, hint_globals) |
583 | 578 |
|
584 | 579 | params = list(signature(fn).parameters.values()) |
@@ -632,7 +627,7 @@ def _fn_decl( |
632 | 627 | ) |
633 | 628 | decls |= merged |
634 | 629 |
|
635 | | - # defer this in generator so it doesnt resolve for builtins eagerly |
| 630 | + # defer this in generator so it doesn't resolve for builtins eagerly |
636 | 631 | args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True)) |
637 | 632 | res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef |
638 | 633 | res_thunk: Callable[[], object] |
@@ -676,7 +671,7 @@ def _fn_decl( |
676 | 671 | ) |
677 | 672 | res_ref = ref |
678 | 673 | decls.set_function_decl(ref, decl) |
679 | | - res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset) |
| 674 | + res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}") |
680 | 675 | return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume) |
681 | 676 |
|
682 | 677 |
|
@@ -1045,8 +1040,7 @@ def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport: |
1045 | 1040 | bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n)))) |
1046 | 1041 | ) |
1047 | 1042 | except BaseException as e: |
1048 | | - e.add_note("Extracting: " + str(expr)) |
1049 | | - raise |
| 1043 | + raise add_note("Extracting: " + str(expr), e) # noqa: B904 |
1050 | 1044 | extract_report = self._egraph.extract_report() |
1051 | 1045 | if not extract_report: |
1052 | 1046 | msg = "No extract report saved" |
|
0 commit comments