Skip to content

Commit 0fc9674

Browse files
Merge pull request #335 from egraphs-good/update-egglog
Bump Egglog version
2 parents 85d283c + c2ab937 commit 0fc9674

File tree

12 files changed

+351
-336
lines changed

12 files changed

+351
-336
lines changed

Cargo.lock

Lines changed: 101 additions & 215 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ crate-type = ["cdylib"]
1212
[dependencies]
1313
pyo3 = { version = "0.24.2", features = ["extension-module"] }
1414

15-
egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "5542549" }
15+
egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false }
1616
# egglog = { path = "../egg-smol" }
17-
egglog-bridge = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "cd51d04" }
18-
core-relations = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "cd51d04" }
19-
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", rev = "255b67a" }
17+
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
18+
core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
19+
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "cli", default-features = false }
2020
egraph-serialize = { version = "0.2.0", features = ["serde", "graphviz"] }
2121
serde_json = "1.0.140"
2222
pyo3-log = "0.12.4"
@@ -28,7 +28,7 @@ rayon = "1.10.0"
2828

2929
# Use patched version of egglog in experimental
3030
[patch.'https://github.com/egraphs-good/egglog']
31-
# egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "d2fa5b733de0796fb187dc5a27e570d5644aa75a" }
31+
egglog = { git = "https://github.com/egraphs-good//egglog.git", branch = "main" }
3232
# egglog = { path = "../egg-smol" }
3333
# egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "5542549" }
3434

python/egglog/bindings.pyi

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,30 @@ from typing_extensions import final
77
__all__ = [
88
"ActionCommand",
99
"AddRuleset",
10-
"Best",
1110
"BiRewriteCommand",
1211
"Bool",
12+
"CSVPrintFunctionMode",
1313
"Call",
1414
"Change",
1515
"Check",
1616
"Constructor",
1717
"Datatype",
1818
"Datatypes",
19+
"DefaultPrintFunctionMode",
1920
"Delete",
2021
"EGraph",
2122
"EggSmolError",
2223
"EgglogSpan",
2324
"Eq",
2425
"Expr_",
2526
"Extract",
27+
"ExtractBest",
28+
"ExtractVariants",
2629
"Fact",
2730
"Fail",
2831
"Float",
2932
"Function",
33+
"FunctionCommand",
3034
"IdentSort",
3135
"Include",
3236
"Input",
@@ -35,10 +39,14 @@ __all__ = [
3539
"Lit",
3640
"NewSort",
3741
"Output",
42+
"OverallStatistics",
3843
"Panic",
3944
"PanicSpan",
4045
"Pop",
46+
"PrintAllFunctionsSize",
4147
"PrintFunction",
48+
"PrintFunctionOutput",
49+
"PrintFunctionSize",
4250
"PrintOverallStatistics",
4351
"PrintSize",
4452
"Push",
@@ -53,13 +61,13 @@ __all__ = [
5361
"RunConfig",
5462
"RunReport",
5563
"RunSchedule",
64+
"RunScheduleOutput",
5665
"RustSpan",
5766
"Saturate",
5867
"Schema",
5968
"Sequence",
6069
"SerializedEGraph",
6170
"Set",
62-
"SetOption",
6371
"Sort",
6472
"SrcFile",
6573
"String",
@@ -73,13 +81,18 @@ __all__ = [
7381
"Unit",
7482
"UnstableCombinedRuleset",
7583
"UserDefined",
84+
"UserDefinedCommandOutput",
85+
"UserDefinedOutput",
7686
"Var",
7787
"Variant",
78-
"Variants",
7988
]
8089

8190
@final
8291
class SerializedEGraph:
92+
@property
93+
def truncated_functions(self) -> list[str]: ...
94+
@property
95+
def discarded_functions(self) -> list[str]: ...
8396
def inline_leaves(self) -> None: ...
8497
def saturate_inline_leaves(self) -> None: ...
8598
def to_dot(self) -> str: ...
@@ -106,9 +119,7 @@ class EGraph:
106119
) -> None: ...
107120
def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
108121
def commands(self) -> str | None: ...
109-
def run_program(self, *commands: _Command) -> list[str]: ...
110-
def extract_report(self) -> _ExtractReport | None: ...
111-
def run_report(self) -> RunReport | None: ...
122+
def run_program(self, *commands: _Command) -> list[_CommandOutput]: ...
112123
def serialize(
113124
self,
114125
root_eclasses: list[_Expr],
@@ -356,6 +367,13 @@ class IdentSort:
356367
sort: str
357368
def __init__(self, ident: str, sort: str) -> None: ...
358369

370+
@final
371+
class UserDefinedCommandOutput: ...
372+
373+
@final
374+
class Function:
375+
name: str
376+
359377
@final
360378
class RunReport:
361379
updated: bool
@@ -375,20 +393,80 @@ class RunReport:
375393
rebuild_time_per_ruleset: dict[str, timedelta],
376394
) -> None: ...
377395

396+
##
397+
# Command Outputs
398+
##
399+
400+
@final
401+
class PrintFunctionSize:
402+
size: int
403+
def __init__(self, size: int) -> None: ...
404+
405+
@final
406+
class PrintAllFunctionsSize:
407+
sizes: list[tuple[str, int]]
408+
def __init__(self, sizes: list[tuple[str, int]]) -> None: ...
409+
378410
@final
379-
class Variants:
411+
class ExtractVariants:
380412
termdag: TermDag
381413
terms: list[_Term]
382414
def __init__(self, termdag: TermDag, terms: list[_Term]) -> None: ...
383415

384416
@final
385-
class Best:
417+
class ExtractBest:
386418
termdag: TermDag
387419
cost: int
388420
term: _Term
389421
def __init__(self, termdag: TermDag, cost: int, term: _Term) -> None: ...
390422

391-
_ExtractReport: TypeAlias = Variants | Best
423+
@final
424+
class OverallStatistics:
425+
report: RunReport
426+
def __init__(self, report: RunReport) -> None: ...
427+
428+
@final
429+
class RunScheduleOutput:
430+
report: RunReport
431+
def __init__(self, report: RunReport) -> None: ...
432+
433+
@final
434+
class PrintFunctionOutput:
435+
function: Function
436+
termdag: TermDag
437+
terms: list[tuple[_Term, _Term]]
438+
mode: _PrintFunctionMode
439+
def __init__(
440+
self, function: Function, termdag: TermDag, terms: list[tuple[_Term, _Term]], mode: _PrintFunctionMode
441+
) -> None: ...
442+
443+
@final
444+
class UserDefinedOutput:
445+
output: UserDefinedCommandOutput
446+
def __init__(self, output: UserDefinedCommandOutput) -> None: ...
447+
448+
_CommandOutput: TypeAlias = (
449+
PrintFunctionSize
450+
| PrintAllFunctionsSize
451+
| ExtractVariants
452+
| ExtractBest
453+
| OverallStatistics
454+
| RunScheduleOutput
455+
| PrintFunctionOutput
456+
| UserDefinedOutput
457+
)
458+
459+
##
460+
# Print Function Modes
461+
##
462+
463+
@final
464+
class DefaultPrintFunctionMode: ...
465+
466+
@final
467+
class CSVPrintFunctionMode: ...
468+
469+
_PrintFunctionMode: TypeAlias = DefaultPrintFunctionMode | CSVPrintFunctionMode
392470

393471
##
394472
# Schedules
@@ -442,12 +520,6 @@ _Subdatatypes: TypeAlias = SubVariants | NewSort
442520
# Commands
443521
##
444522

445-
@final
446-
class SetOption:
447-
name: str
448-
value: _Expr
449-
def __init__(self, name: str, value: _Expr) -> None: ...
450-
451523
@final
452524
class Datatype:
453525
span: _Span
@@ -469,7 +541,7 @@ class Sort:
469541
def __init__(self, span: _Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
470542

471543
@final
472-
class Function:
544+
class FunctionCommand:
473545
span: _Span
474546
name: str
475547
schema: Schema
@@ -531,8 +603,12 @@ class Check:
531603
class PrintFunction:
532604
span: _Span
533605
name: str
534-
length: int
535-
def __init__(self, span: _Span, name: str, length: int) -> None: ...
606+
length: int | None
607+
filename: str | None
608+
mode: _PrintFunctionMode
609+
def __init__(
610+
self, span: _Span, name: str, length: int | None, filename: str | None, mode: _PrintFunctionMode
611+
) -> None: ...
536612

537613
@final
538614
class PrintSize:
@@ -613,11 +689,10 @@ class UnstableCombinedRuleset:
613689
def __init__(self, span: _Span, name: str, rulesets: list[str]) -> None: ...
614690

615691
_Command: TypeAlias = (
616-
SetOption
617-
| Datatype
692+
Datatype
618693
| Datatypes
619694
| Sort
620-
| Function
695+
| FunctionCommand
621696
| AddRuleset
622697
| RuleCommand
623698
| RewriteCommand

python/egglog/egraph.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -904,12 +904,9 @@ def run(
904904
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
905905
self._add_decls(schedule)
906906
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
907-
self._egraph.run_program(bindings.RunSchedule(egg_schedule))
908-
run_report = self._egraph.run_report()
909-
if not run_report:
910-
msg = "No run report saved"
911-
raise ValueError(msg)
912-
return run_report
907+
(command_output,) = self._egraph.run_program(bindings.RunSchedule(egg_schedule))
908+
assert isinstance(command_output, bindings.RunScheduleOutput)
909+
return command_output.report
913910

914911
def check_bool(self, *facts: FactLike) -> bool:
915912
"""
@@ -954,10 +951,7 @@ def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tu
954951
"""
955952
runtime_expr = to_runtime_expr(expr)
956953
extract_report = self._run_extract(runtime_expr, 0)
957-
958-
if not isinstance(extract_report, bindings.Best):
959-
msg = "No extract report saved"
960-
raise ValueError(msg) # noqa: TRY004
954+
assert isinstance(extract_report, bindings.ExtractBest)
961955
(new_typed_expr,) = self._state.exprs_from_egg(
962956
extract_report.termdag, [extract_report.term], runtime_expr.__egg_typed_expr__.tp
963957
)
@@ -973,26 +967,19 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
973967
"""
974968
runtime_expr = to_runtime_expr(expr)
975969
extract_report = self._run_extract(runtime_expr, n)
976-
if not isinstance(extract_report, bindings.Variants):
977-
msg = "Wrong extract report type"
978-
raise ValueError(msg) # noqa: TRY004
970+
assert isinstance(extract_report, bindings.ExtractVariants)
979971
new_exprs = self._state.exprs_from_egg(
980972
extract_report.termdag, extract_report.terms, runtime_expr.__egg_typed_expr__.tp
981973
)
982974
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
983975

984-
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport:
976+
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
985977
self._add_decls(expr)
986978
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
987979
try:
988-
self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
980+
return self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))[0]
989981
except BaseException as e:
990982
raise add_note("Extracting: " + str(expr), e) # noqa: B904
991-
extract_report = self._egraph.extract_report()
992-
if not extract_report:
993-
msg = "No extract report saved"
994-
raise ValueError(msg)
995-
return extract_report
996983

997984
def push(self) -> None:
998985
"""

python/egglog/egraph_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
263263
self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input))
264264
else:
265265
self.egraph.run_program(
266-
bindings.Function(
266+
bindings.FunctionCommand(
267267
span(),
268268
egg_name,
269269
self._signature_to_egg_schema(signature),

python/egglog/exp/array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
729729
_DTYPES = [float64, float32, int32, int64, DType.object]
730730

731731
converter(type, DType, lambda x: convert(np.dtype(x), DType))
732-
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
732+
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type:ignore[call-overload]
733733

734734

735735
@array_api_ruleset.register

python/tests/test_bindings.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_parse_and_run_program_exception(self):
9595

9696
def test_run_rules(self):
9797
egraph = EGraph()
98-
egraph.run_program(
98+
res = egraph.run_program(
9999
Datatype(DUMMY_SPAN, "Math", [Variant(DUMMY_SPAN, "Add", ["Math", "Math"])]),
100100
RewriteCommand(
101101
"",
@@ -109,20 +109,21 @@ def test_run_rules(self):
109109
RunSchedule(Repeat(DUMMY_SPAN, 10, Run(DUMMY_SPAN, RunConfig("")))),
110110
)
111111

112-
run_report = egraph.run_report()
113-
assert isinstance(run_report, RunReport)
112+
assert len(res) == 1
113+
assert isinstance(res[0], RunScheduleOutput)
114114

115115
def test_extract(self):
116116
# Example from extraction-cost
117117
egraph = EGraph()
118-
egraph.run_program(
118+
res = egraph.run_program(
119119
Datatype(DUMMY_SPAN, "Expr", [Variant(DUMMY_SPAN, "Num", ["i64"], cost=5)]),
120120
ActionCommand(Let(DUMMY_SPAN, "x", Call(DUMMY_SPAN, "Num", [Lit(DUMMY_SPAN, Int(1))]))),
121121
Extract(DUMMY_SPAN, Var(DUMMY_SPAN, "x"), Lit(DUMMY_SPAN, Int(0))),
122122
)
123123

124-
extract_report = egraph.extract_report()
125-
assert isinstance(extract_report, Best)
124+
assert len(res) == 1
125+
extract_report = res[0]
126+
assert isinstance(extract_report, ExtractBest)
126127
assert extract_report.cost == 6
127128
assert extract_report.termdag.term_to_expr(extract_report.term, DUMMY_SPAN) == Call(
128129
DUMMY_SPAN, "Num", [Lit(DUMMY_SPAN, Int(1))]

0 commit comments

Comments
 (0)