Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 231 additions & 14 deletions Cargo.lock

Large diffs are not rendered by default.

27 changes: 11 additions & 16 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational"] }
num-bigint = "*"
num-rational = "*"
# egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false }
# egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
# egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again", default-features = false }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "update-egglog", default-features = false }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false }
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
serde_json = "1"
pyo3-log = { git = "https://github.com/a1phyr/pyo3-log.git", branch = "pyo3_0.27" }
pyo3-log = "*"
log = "0.4"
lalrpop-util = { version = "0.22", features = ["lexer"] }
ordered-float = "5"
Expand All @@ -33,13 +31,10 @@ base64 = "0.22.1"

# Use patched version of egglog in experimental
[patch.'https://github.com/egraphs-good/egglog']
# egglog = { git = "https://github.com/egraphs-good//egglog.git", branch = "main" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "let-bindings-again" }
# egglog = { path = "../egg-smol" }
# egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "5542549" }
# egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
# egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
# egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
# egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }

# enable debug symbols for easier profiling
[profile.release]
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _This project uses semantic versioning_

## UNRELEASED

- Add support for setting report level with `egraph.set_report_level` [#375](https://github.com/egraphs-good/egglog-python/pull/375)
- Make docs builds fail on notebook execution errors and fix all doc issues [#369](https://github.com/egraphs-good/egglog-python/pull/369)
- Add WIP `egglog.exp.any_expr` code for tracing arbitrary expressions with Python fallback [#366](https://github.com/egraphs-good/egglog-python/pull/366)
- BREAKING: Remove support for Python 3.11 now that pyo3 has dropped support.
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from . import config, ipython_magic # noqa: F401
from .bindings import EggSmolError # noqa: F401
from .bindings import EggSmolError, StageInfo, TimeOnly, WithPlan # noqa: F401
from .builtins import * # noqa: UP029
from .conversion import *
from .deconstruct import *
Expand Down
107 changes: 106 additions & 1 deletion python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,21 @@ __all__ = [
"Float",
"Function",
"FunctionCommand",
"FusedIntersect",
"IdentSort",
"Include",
"Input",
"Int",
"Intersect",
"IterationReport",
"Let",
"Lit",
"NewSort",
"Output",
"OverallStatistics",
"Panic",
"PanicSpan",
"Plan",
"Pop",
"PrintAllFunctionsSize",
"PrintFunction",
Expand All @@ -58,26 +62,33 @@ __all__ = [
"RewriteCommand",
"Rule",
"RuleCommand",
"RuleReport",
"RuleSetReport",
"Run",
"RunConfig",
"RunReport",
"RunSchedule",
"RunScheduleOutput",
"RustSpan",
"Saturate",
"Scan",
"Schema",
"Sequence",
"SerializedEGraph",
"Set",
"SingleScan",
"Sort",
"SrcFile",
"StageInfo",
"StageStats",
"String",
"SubVariants",
"Subsume",
"TermApp",
"TermDag",
"TermLit",
"TermVar",
"TimeOnly",
"Union",
"Unit",
"UnstableCombinedRuleset",
Expand All @@ -87,6 +98,7 @@ __all__ = [
"Value",
"Var",
"Variant",
"WithPlan",
]

@final
Expand Down Expand Up @@ -118,6 +130,7 @@ class EGraph:
max_calls_per_function: int | None = None,
include_temporary_functions: bool = False,
) -> SerializedEGraph: ...
def set_report_level(self, level: _ReportLevel) -> None: ...
def lookup_function(self, name: str, key: list[Value]) -> Value | None: ...
def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ...
def value_to_i64(self, v: Value) -> int: ...
Expand Down Expand Up @@ -389,12 +402,99 @@ class IdentSort:
@final
class UserDefinedCommandOutput: ...

@final
class SingleScan:
atom: str
column: tuple[str, int]

def __new__(cls, atom: str, column: tuple[str, int]) -> SingleScan: ...

@final
class Scan:
atom: str
columns: list[tuple[str, int]]

def __new__(cls, atom: str, columns: list[tuple[str, int]]) -> Scan: ...

@final
class StageStats:
num_candidates: int
num_succeeded: int

def __new__(cls, num_candidates: int, num_succeeded: int) -> StageStats: ...

@final
class TimeOnly:
def __new__(cls) -> TimeOnly: ...

@final
class WithPlan:
def __new__(cls) -> WithPlan: ...

@final
class StageInfo:
def __new__(cls) -> StageInfo: ...

_ReportLevel: TypeAlias = TimeOnly | WithPlan | StageInfo

@final
class Intersect:
scans: list[SingleScan]

def __new__(cls, scans: list[SingleScan]) -> Intersect: ...

@final
class FusedIntersect:
cover: Scan
to_intersect: list[Scan]

def __new__(cls, cover: Scan, to_intersect: list[Scan]) -> FusedIntersect: ...

_Stage: TypeAlias = Intersect | FusedIntersect

@final
class Plan:
stages: list[tuple[_Stage, StageStats | None, list[int]]]

def __new__(cls, stages: list[tuple[_Stage, StageStats | None, list[int]]]) -> Plan: ...

@final
class RuleReport:
plan: Plan | None
search_and_apply_time: timedelta
num_matches: int

def __new__(cls, plan: Plan | None, search_and_apply_time: timedelta, num_matches: int) -> RuleReport: ...

@final
class RuleSetReport:
changed: bool
rule_reports: dict[str, list[RuleReport]]
search_and_apply_time: timedelta
merge_time: timedelta

def __new__(
cls,
changed: bool,
rule_reports: dict[str, list[RuleReport]],
search_and_apply_time: timedelta,
merge_time: timedelta,
) -> RuleSetReport: ...

@final
class IterationReport:
rule_set_report: RuleSetReport
rebuild_time: timedelta

def __new__(cls, rule_set_report: RuleSetReport, rebuild_time: timedelta) -> IterationReport: ...

@final
class Function:
name: str

@final
class RunReport:
iterations: list[IterationReport]
updated: bool
search_and_apply_time_per_rule: dict[str, timedelta]
num_matches_per_rule: dict[str, int]
Expand All @@ -404,6 +504,7 @@ class RunReport:

def __new__(
cls,
iterations: list[IterationReport],
updated: bool,
search_and_apply_time_per_rule: dict[str, timedelta],
num_matches_per_rule: dict[str, int],
Expand Down Expand Up @@ -688,7 +789,11 @@ class Constructor:
def __new__(cls, span: _Span, name: str, schema: Schema, cost: int | None, unextractable: bool) -> Constructor: ...

@final
class PrintOverallStatistics: ...
class PrintOverallStatistics:
span: _Span
file: str | None

def __new__(cls, span: _Span, file: str | None) -> PrintOverallStatistics: ...

@final
class UserDefined:
Expand Down
9 changes: 8 additions & 1 deletion python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
"vars_",
]


T = TypeVar("T")
P = ParamSpec("P")
EXPR_TYPE = TypeVar("EXPR_TYPE", bound="type[Expr]")
Expand Down Expand Up @@ -869,6 +870,12 @@ def _add_decls(self, *decls: DeclerationsLike) -> None:
for d in decls:
self._state.__egg_decls__ |= d

def set_report_level(self, level: bindings._ReportLevel) -> None:
"""
Set the level of detail recorded in subsequent run reports.
"""
self._egraph.set_report_level(level)

@property
def as_egglog_string(self) -> str:
"""
Expand Down Expand Up @@ -948,7 +955,7 @@ def stats(self) -> bindings.RunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._egraph.run_program(bindings.PrintOverallStatistics())
(output,) = self._egraph.run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report

Expand Down
Loading
Loading