Skip to content

Commit 7c24e80

Browse files
Switch to using add_note
1 parent 37bf0e2 commit 7c24e80

File tree

4 files changed

+75
-52
lines changed

4 files changed

+75
-52
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ _This project uses semantic versioning_
1111
- Subsumes lambda functions after replacing
1212
- Add working loopnest test
1313
- Improve tracebacks on failing conversions.
14+
- Use `add_note` for exception to add more context, instead of raising a new exception, to make it easier to debug.
1415

1516
## 8.0.1 (2024-10-24)
1617

python/egglog/egraph.py

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -42,53 +42,53 @@
4242

4343

4444
__all__ = [
45+
"Action",
46+
"Command",
47+
"Command",
4548
"EGraph",
46-
"Module",
47-
"function",
48-
"ruleset",
49-
"method",
50-
"relation",
5149
"Expr",
50+
"Fact",
51+
"Fact",
52+
"GraphvizKwargs",
53+
"Module",
54+
"RewriteOrRule",
55+
"Ruleset",
56+
"Schedule",
5257
"Unit",
53-
"rewrite",
58+
"_BirewriteBuilder",
59+
"_EqBuilder",
60+
"_NeBuilder",
61+
"_RewriteBuilder",
62+
"_SetBuilder",
63+
"_UnionBuilder",
64+
"action_command",
5465
"birewrite",
55-
"eq",
56-
"ne",
57-
"panic",
58-
"let",
66+
"check",
67+
"check_eq",
5968
"constant",
6069
"delete",
61-
"subsume",
62-
"union",
63-
"set_",
64-
"rule",
65-
"var",
66-
"vars_",
67-
"Fact",
68-
"expr_parts",
70+
"eq",
6971
"expr_action",
7072
"expr_fact",
71-
"action_command",
72-
"Schedule",
73+
"expr_parts",
74+
"function",
75+
"let",
76+
"method",
77+
"ne",
78+
"panic",
79+
"relation",
80+
"rewrite",
81+
"rule",
82+
"ruleset",
7383
"run",
7484
"seq",
75-
"Command",
85+
"set_",
7686
"simplify",
87+
"subsume",
88+
"union",
7789
"unstable_combine_rulesets",
78-
"check",
79-
"GraphvizKwargs",
80-
"Ruleset",
81-
"_RewriteBuilder",
82-
"_BirewriteBuilder",
83-
"_EqBuilder",
84-
"_NeBuilder",
85-
"_SetBuilder",
86-
"_UnionBuilder",
87-
"RewriteOrRule",
88-
"Fact",
89-
"Action",
90-
"Command",
91-
"check_eq",
90+
"var",
91+
"vars_",
9292
]
9393

9494
T = TypeVar("T")
@@ -146,20 +146,29 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
146146
return EGraph().extract(x)
147147

148148

149-
def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
149+
def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None, *, add_second=True, display=False) -> EGraph:
150150
"""
151151
Verifies that two expressions are equal after running the schedule.
152+
153+
If add_second is true, then the second expression is added to the egraph before running the schedule.
152154
"""
153155
egraph = EGraph()
154156
x_var = egraph.let("__check_eq_x", x)
155-
y_var = egraph.let("__check_eq_y", y)
157+
y_var: EXPR = egraph.let("__check_eq_y", y) if add_second else y
156158
if schedule:
157-
egraph.run(schedule)
159+
try:
160+
egraph.run(schedule)
161+
finally:
162+
if display:
163+
egraph.display()
158164
fact = eq(x_var).to(y_var)
159165
try:
160166
egraph.check(fact)
161167
except bindings.EggSmolError as err:
162-
raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
168+
if display:
169+
egraph.display()
170+
err.add_note(f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})")
171+
raise
163172
return egraph
164173

165174

@@ -598,8 +607,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
598607
unextractable=unextractable,
599608
subsume=subsume,
600609
)
601-
except ValueError as e:
602-
raise ValueError(f"Error processing {cls_name}.{method_name}: {e}") from e
610+
except Exception as e:
611+
e.add_note(f"Error processing {cls_name}.{method_name}")
612+
raise
603613

604614
if not builtin and not isinstance(ref, InitRef) and not mutates:
605615
add_default_funcs.append(add_rewrite)
@@ -1389,10 +1399,16 @@ def to_json() -> str:
13891399

13901400
egraphs = [to_json()]
13911401
i = 0
1392-
while self.run(schedule or 1).updated and i < max:
1393-
i += 1
1402+
# Always visualize, even if we encounter an error
1403+
try:
1404+
while (self.run(schedule or 1).updated) and i < max:
1405+
i += 1
1406+
egraphs.append(to_json())
1407+
except:
13941408
egraphs.append(to_json())
1395-
VisualizerWidget(egraphs=egraphs).display_or_open()
1409+
raise
1410+
finally:
1411+
VisualizerWidget(egraphs=egraphs).display_or_open()
13961412

13971413
@classmethod
13981414
def current(cls) -> EGraph:

python/egglog/exp/array_api.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,9 +1671,10 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
16711671
try:
16721672
expr_extracted = egraph.extract(expr)
16731673
except BaseException as inner_exc:
1674-
raise ValueError(f"Cannot simplify {expr}") from inner_exc
1675-
msg = f"Cannot simplify to primitive {expr_extracted}"
1676-
raise ValueError(msg) from exc
1674+
inner_exc.add_note(f"Cannot simplify {expr}")
1675+
raise
1676+
exc.add_note(f"Cannot simplify to primitive {expr_extracted}")
1677+
raise
16771678
return egraph.eval(extracted)
16781679

16791680
# string = (

python/egglog/runtime.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929

3030
__all__ = [
3131
"LIT_CLASS_NAMES",
32-
"resolve_callable",
33-
"resolve_type_annotation",
32+
"REFLECTED_BINARY_METHODS",
3433
"RuntimeClass",
3534
"RuntimeExpr",
3635
"RuntimeFunction",
37-
"REFLECTED_BINARY_METHODS",
36+
"resolve_callable",
37+
"resolve_type_annotation",
3838
]
3939

4040

@@ -199,7 +199,11 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
199199
}:
200200
raise AttributeError
201201

202-
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
202+
try:
203+
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
204+
except Exception as e:
205+
e.add_note(f"Error processing class {self.__egg_tp__.name}")
206+
raise
203207

204208
preserved_methods = cls_decl.preserved_methods
205209
if name in preserved_methods:
@@ -259,7 +263,8 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs:
259263
try:
260264
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
261265
except Exception as e:
262-
raise TypeError(f"Failed to find callable {self}") from e
266+
e.add_note(f"Failed to find callable {self}")
267+
raise
263268
decls = self.__egg_decls__.copy()
264269
# Special case function application bc we dont support variadic generics yet generally
265270
if signature == "fn-app":

0 commit comments

Comments
 (0)