Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 101 additions & 215 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.24.2", features = ["extension-module"] }

egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "5542549" }
egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false }
# egglog = { path = "../egg-smol" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "cd51d04" }
core-relations = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "cd51d04" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", rev = "255b67a" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "cli", default-features = false }
egraph-serialize = { version = "0.2.0", features = ["serde", "graphviz"] }
serde_json = "1.0.140"
pyo3-log = "0.12.4"
Expand All @@ -28,7 +28,7 @@ rayon = "1.10.0"

# Use patched version of egglog in experimental
[patch.'https://github.com/egraphs-good/egglog']
# egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "d2fa5b733de0796fb187dc5a27e570d5644aa75a" }
egglog = { git = "https://github.com/egraphs-good//egglog.git", branch = "main" }
# egglog = { path = "../egg-smol" }
# egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "5542549" }

Expand Down
117 changes: 96 additions & 21 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,30 @@ from typing_extensions import final
__all__ = [
"ActionCommand",
"AddRuleset",
"Best",
"BiRewriteCommand",
"Bool",
"CSVPrintFunctionMode",
"Call",
"Change",
"Check",
"Constructor",
"Datatype",
"Datatypes",
"DefaultPrintFunctionMode",
"Delete",
"EGraph",
"EggSmolError",
"EgglogSpan",
"Eq",
"Expr_",
"Extract",
"ExtractBest",
"ExtractVariants",
"Fact",
"Fail",
"Float",
"Function",
"FunctionCommand",
"IdentSort",
"Include",
"Input",
Expand All @@ -35,10 +39,14 @@ __all__ = [
"Lit",
"NewSort",
"Output",
"OverallStatistics",
"Panic",
"PanicSpan",
"Pop",
"PrintAllFunctionsSize",
"PrintFunction",
"PrintFunctionOutput",
"PrintFunctionSize",
"PrintOverallStatistics",
"PrintSize",
"Push",
Expand All @@ -53,13 +61,13 @@ __all__ = [
"RunConfig",
"RunReport",
"RunSchedule",
"RunScheduleOutput",
"RustSpan",
"Saturate",
"Schema",
"Sequence",
"SerializedEGraph",
"Set",
"SetOption",
"Sort",
"SrcFile",
"String",
Expand All @@ -73,13 +81,18 @@ __all__ = [
"Unit",
"UnstableCombinedRuleset",
"UserDefined",
"UserDefinedCommandOutput",
"UserDefinedOutput",
"Var",
"Variant",
"Variants",
]

@final
class SerializedEGraph:
@property
def truncated_functions(self) -> list[str]: ...
@property
def discarded_functions(self) -> list[str]: ...
def inline_leaves(self) -> None: ...
def saturate_inline_leaves(self) -> None: ...
def to_dot(self) -> str: ...
Expand All @@ -106,9 +119,7 @@ class EGraph:
) -> None: ...
def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ...
def commands(self) -> str | None: ...
def run_program(self, *commands: _Command) -> list[str]: ...
def extract_report(self) -> _ExtractReport | None: ...
def run_report(self) -> RunReport | None: ...
def run_program(self, *commands: _Command) -> list[_CommandOutput]: ...
def serialize(
self,
root_eclasses: list[_Expr],
Expand Down Expand Up @@ -356,6 +367,13 @@ class IdentSort:
sort: str
def __init__(self, ident: str, sort: str) -> None: ...

@final
class UserDefinedCommandOutput: ...

@final
class Function:
name: str

@final
class RunReport:
updated: bool
Expand All @@ -375,20 +393,80 @@ class RunReport:
rebuild_time_per_ruleset: dict[str, timedelta],
) -> None: ...

##
# Command Outputs
##

@final
class PrintFunctionSize:
size: int
def __init__(self, size: int) -> None: ...

@final
class PrintAllFunctionsSize:
sizes: list[tuple[str, int]]
def __init__(self, sizes: list[tuple[str, int]]) -> None: ...

@final
class Variants:
class ExtractVariants:
termdag: TermDag
terms: list[_Term]
def __init__(self, termdag: TermDag, terms: list[_Term]) -> None: ...

@final
class Best:
class ExtractBest:
termdag: TermDag
cost: int
term: _Term
def __init__(self, termdag: TermDag, cost: int, term: _Term) -> None: ...

_ExtractReport: TypeAlias = Variants | Best
@final
class OverallStatistics:
report: RunReport
def __init__(self, report: RunReport) -> None: ...

@final
class RunScheduleOutput:
report: RunReport
def __init__(self, report: RunReport) -> None: ...

@final
class PrintFunctionOutput:
function: Function
termdag: TermDag
terms: list[tuple[_Term, _Term]]
mode: _PrintFunctionMode
def __init__(
self, function: Function, termdag: TermDag, terms: list[tuple[_Term, _Term]], mode: _PrintFunctionMode
) -> None: ...

@final
class UserDefinedOutput:
output: UserDefinedCommandOutput
def __init__(self, output: UserDefinedCommandOutput) -> None: ...

_CommandOutput: TypeAlias = (
PrintFunctionSize
| PrintAllFunctionsSize
| ExtractVariants
| ExtractBest
| OverallStatistics
| RunScheduleOutput
| PrintFunctionOutput
| UserDefinedOutput
)

##
# Print Function Modes
##

@final
class DefaultPrintFunctionMode: ...

@final
class CSVPrintFunctionMode: ...

_PrintFunctionMode: TypeAlias = DefaultPrintFunctionMode | CSVPrintFunctionMode

##
# Schedules
Expand Down Expand Up @@ -442,12 +520,6 @@ _Subdatatypes: TypeAlias = SubVariants | NewSort
# Commands
##

@final
class SetOption:
name: str
value: _Expr
def __init__(self, name: str, value: _Expr) -> None: ...

@final
class Datatype:
span: _Span
Expand All @@ -469,7 +541,7 @@ class Sort:
def __init__(self, span: _Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...

@final
class Function:
class FunctionCommand:
span: _Span
name: str
schema: Schema
Expand Down Expand Up @@ -531,8 +603,12 @@ class Check:
class PrintFunction:
span: _Span
name: str
length: int
def __init__(self, span: _Span, name: str, length: int) -> None: ...
length: int | None
filename: str | None
mode: _PrintFunctionMode
def __init__(
self, span: _Span, name: str, length: int | None, filename: str | None, mode: _PrintFunctionMode
) -> None: ...

@final
class PrintSize:
Expand Down Expand Up @@ -613,11 +689,10 @@ class UnstableCombinedRuleset:
def __init__(self, span: _Span, name: str, rulesets: list[str]) -> None: ...

_Command: TypeAlias = (
SetOption
| Datatype
Datatype
| Datatypes
| Sort
| Function
| FunctionCommand
| AddRuleset
| RuleCommand
| RewriteCommand
Expand Down
27 changes: 7 additions & 20 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,12 +904,9 @@ def run(
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
self._add_decls(schedule)
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
self._egraph.run_program(bindings.RunSchedule(egg_schedule))
run_report = self._egraph.run_report()
if not run_report:
msg = "No run report saved"
raise ValueError(msg)
return run_report
(command_output,) = self._egraph.run_program(bindings.RunSchedule(egg_schedule))
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down Expand Up @@ -954,10 +951,7 @@ def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tu
"""
runtime_expr = to_runtime_expr(expr)
extract_report = self._run_extract(runtime_expr, 0)

if not isinstance(extract_report, bindings.Best):
msg = "No extract report saved"
raise ValueError(msg) # noqa: TRY004
assert isinstance(extract_report, bindings.ExtractBest)
(new_typed_expr,) = self._state.exprs_from_egg(
extract_report.termdag, [extract_report.term], runtime_expr.__egg_typed_expr__.tp
)
Expand All @@ -973,26 +967,19 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
"""
runtime_expr = to_runtime_expr(expr)
extract_report = self._run_extract(runtime_expr, n)
if not isinstance(extract_report, bindings.Variants):
msg = "Wrong extract report type"
raise ValueError(msg) # noqa: TRY004
assert isinstance(extract_report, bindings.ExtractVariants)
new_exprs = self._state.exprs_from_egg(
extract_report.termdag, extract_report.terms, runtime_expr.__egg_typed_expr__.tp
)
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]

def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport:
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
self._add_decls(expr)
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
try:
self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
return self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))[0]
except BaseException as e:
raise add_note("Extracting: " + str(expr), e) # noqa: B904
extract_report = self._egraph.extract_report()
if not extract_report:
msg = "No extract report saved"
raise ValueError(msg)
return extract_report

def push(self) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input))
else:
self.egraph.run_program(
bindings.Function(
bindings.FunctionCommand(
span(),
egg_name,
self._signature_to_egg_schema(signature),
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def __eq__(self, other: DType) -> Boolean: # type: ignore[override]
_DTYPES = [float64, float32, int32, int64, DType.object]

converter(type, DType, lambda x: convert(np.dtype(x), DType))
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload]
converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type:ignore[call-overload]


@array_api_ruleset.register
Expand Down
13 changes: 7 additions & 6 deletions python/tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_parse_and_run_program_exception(self):

def test_run_rules(self):
egraph = EGraph()
egraph.run_program(
res = egraph.run_program(
Datatype(DUMMY_SPAN, "Math", [Variant(DUMMY_SPAN, "Add", ["Math", "Math"])]),
RewriteCommand(
"",
Expand All @@ -109,20 +109,21 @@ def test_run_rules(self):
RunSchedule(Repeat(DUMMY_SPAN, 10, Run(DUMMY_SPAN, RunConfig("")))),
)

run_report = egraph.run_report()
assert isinstance(run_report, RunReport)
assert len(res) == 1
assert isinstance(res[0], RunScheduleOutput)

def test_extract(self):
# Example from extraction-cost
egraph = EGraph()
egraph.run_program(
res = egraph.run_program(
Datatype(DUMMY_SPAN, "Expr", [Variant(DUMMY_SPAN, "Num", ["i64"], cost=5)]),
ActionCommand(Let(DUMMY_SPAN, "x", Call(DUMMY_SPAN, "Num", [Lit(DUMMY_SPAN, Int(1))]))),
Extract(DUMMY_SPAN, Var(DUMMY_SPAN, "x"), Lit(DUMMY_SPAN, Int(0))),
)

extract_report = egraph.extract_report()
assert isinstance(extract_report, Best)
assert len(res) == 1
extract_report = res[0]
assert isinstance(extract_report, ExtractBest)
assert extract_report.cost == 6
assert extract_report.termdag.term_to_expr(extract_report.term, DUMMY_SPAN) == Call(
DUMMY_SPAN, "Num", [Lit(DUMMY_SPAN, Int(1))]
Expand Down
Loading
Loading