|
42 | 42 |
|
43 | 43 |
|
44 | 44 | __all__ = [ |
| 45 | + "Action", |
| 46 | + "Command", |
| 47 | + "Command", |
45 | 48 | "EGraph", |
46 | | - "Module", |
47 | | - "function", |
48 | | - "ruleset", |
49 | | - "method", |
50 | | - "relation", |
51 | 49 | "Expr", |
| 50 | + "Fact", |
| 51 | + "Fact", |
| 52 | + "GraphvizKwargs", |
| 53 | + "Module", |
| 54 | + "RewriteOrRule", |
| 55 | + "Ruleset", |
| 56 | + "Schedule", |
52 | 57 | "Unit", |
53 | | - "rewrite", |
| 58 | + "_BirewriteBuilder", |
| 59 | + "_EqBuilder", |
| 60 | + "_NeBuilder", |
| 61 | + "_RewriteBuilder", |
| 62 | + "_SetBuilder", |
| 63 | + "_UnionBuilder", |
| 64 | + "action_command", |
54 | 65 | "birewrite", |
55 | | - "eq", |
56 | | - "ne", |
57 | | - "panic", |
58 | | - "let", |
| 66 | + "check", |
| 67 | + "check_eq", |
59 | 68 | "constant", |
60 | 69 | "delete", |
61 | | - "subsume", |
62 | | - "union", |
63 | | - "set_", |
64 | | - "rule", |
65 | | - "var", |
66 | | - "vars_", |
67 | | - "Fact", |
68 | | - "expr_parts", |
| 70 | + "eq", |
69 | 71 | "expr_action", |
70 | 72 | "expr_fact", |
71 | | - "action_command", |
72 | | - "Schedule", |
| 73 | + "expr_parts", |
| 74 | + "function", |
| 75 | + "let", |
| 76 | + "method", |
| 77 | + "ne", |
| 78 | + "panic", |
| 79 | + "relation", |
| 80 | + "rewrite", |
| 81 | + "rule", |
| 82 | + "ruleset", |
73 | 83 | "run", |
74 | 84 | "seq", |
75 | | - "Command", |
| 85 | + "set_", |
76 | 86 | "simplify", |
| 87 | + "subsume", |
| 88 | + "union", |
77 | 89 | "unstable_combine_rulesets", |
78 | | - "check", |
79 | | - "GraphvizKwargs", |
80 | | - "Ruleset", |
81 | | - "_RewriteBuilder", |
82 | | - "_BirewriteBuilder", |
83 | | - "_EqBuilder", |
84 | | - "_NeBuilder", |
85 | | - "_SetBuilder", |
86 | | - "_UnionBuilder", |
87 | | - "RewriteOrRule", |
88 | | - "Fact", |
89 | | - "Action", |
90 | | - "Command", |
91 | | - "check_eq", |
| 90 | + "var", |
| 91 | + "vars_", |
92 | 92 | ] |
93 | 93 |
|
94 | 94 | T = TypeVar("T") |
@@ -146,20 +146,29 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR: |
146 | 146 | return EGraph().extract(x) |
147 | 147 |
|
148 | 148 |
|
149 | | -def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph: |
| 149 | +def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None, *, add_second=True, display=False) -> EGraph: |
150 | 150 | """ |
151 | 151 | Verifies that two expressions are equal after running the schedule. |
| 152 | +
|
| 153 | + If add_second is true, then the second expression is added to the egraph before running the schedule. |
152 | 154 | """ |
153 | 155 | egraph = EGraph() |
154 | 156 | x_var = egraph.let("__check_eq_x", x) |
155 | | - y_var = egraph.let("__check_eq_y", y) |
| 157 | + y_var: EXPR = egraph.let("__check_eq_y", y) if add_second else y |
156 | 158 | if schedule: |
157 | | - egraph.run(schedule) |
| 159 | + try: |
| 160 | + egraph.run(schedule) |
| 161 | + finally: |
| 162 | + if display: |
| 163 | + egraph.display() |
158 | 164 | fact = eq(x_var).to(y_var) |
159 | 165 | try: |
160 | 166 | egraph.check(fact) |
161 | 167 | except bindings.EggSmolError as err: |
162 | | - raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err |
| 168 | + if display: |
| 169 | + egraph.display() |
| 170 | + err.add_note(f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})") |
| 171 | + raise |
163 | 172 | return egraph |
164 | 173 |
|
165 | 174 |
|
@@ -598,8 +607,9 @@ def _generate_class_decls( # noqa: C901,PLR0912 |
598 | 607 | unextractable=unextractable, |
599 | 608 | subsume=subsume, |
600 | 609 | ) |
601 | | - except ValueError as e: |
602 | | - raise ValueError(f"Error processing {cls_name}.{method_name}: {e}") from e |
| 610 | + except Exception as e: |
| 611 | + e.add_note(f"Error processing {cls_name}.{method_name}") |
| 612 | + raise |
603 | 613 |
|
604 | 614 | if not builtin and not isinstance(ref, InitRef) and not mutates: |
605 | 615 | add_default_funcs.append(add_rewrite) |
@@ -1389,10 +1399,16 @@ def to_json() -> str: |
1389 | 1399 |
|
1390 | 1400 | egraphs = [to_json()] |
1391 | 1401 | i = 0 |
1392 | | - while self.run(schedule or 1).updated and i < max: |
1393 | | - i += 1 |
| 1402 | + # Always visualize, even if we encounter an error |
| 1403 | + try: |
| 1404 | + while (self.run(schedule or 1).updated) and i < max: |
| 1405 | + i += 1 |
| 1406 | + egraphs.append(to_json()) |
| 1407 | + except: |
1394 | 1408 | egraphs.append(to_json()) |
1395 | | - VisualizerWidget(egraphs=egraphs).display_or_open() |
| 1409 | + raise |
| 1410 | + finally: |
| 1411 | + VisualizerWidget(egraphs=egraphs).display_or_open() |
1396 | 1412 |
|
1397 | 1413 | @classmethod |
1398 | 1414 | def current(cls) -> EGraph: |
|
0 commit comments