Skip to content

Commit ea6b003

Browse files
Fix all tests
1 parent 136ab73 commit ea6b003

23 files changed

+780
-267
lines changed

Cargo.lock

Lines changed: 46 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[package]
2-
name = "egglog"
2+
name = "egglog_python"
33
version = "7.2.0"
44
edition = "2021"
55

@@ -13,10 +13,17 @@ pyo3 = { version = "0.21", features = ["extension-module"] }
1313

1414
# https://github.com/egraphs-good/egglog/compare/ceed816e9369570ffed9feeba157b19471dda70d...main
1515
# egglog = { git = "https://github.com/egraphs-good/egglog", rev = "fb4a9f114f9bb93154d6eff0dbab079b5cb4ebb6" }
16-
egglog = { path = "../egg-smol" }
16+
# egglog = { path = "../egg-smol" }
1717
# egglog = { git = "https://github.com/oflatt/egg-smol", branch = "oflatt-fast-terms" }
18-
# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "38b3014b34399cc78887ede09c845b2a5d6c7d19" }
19-
egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "5838c036623e91540831745b1574539e01c8cb23" }
18+
egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "a555b2f5e82c684442775cc1a5da94b71930113c" }
19+
egraph-serialize = { git = "https://github.com/saulshanabrook/egraph-serialize", rev = "1c205fcc6d3426800b828e9264dbadbd4a5ef6e9", features = [
20+
"serde",
21+
"graphviz",
22+
] }
23+
# egraph-serialize = { path = "../egraph-serialize", features = [
24+
# "serde",
25+
# "graphviz",
26+
# ] }
2027
serde_json = "*"
2128
pyo3-log = "0.10.0"
2229
log = "0.4.21"
@@ -26,8 +33,8 @@ uuid = { version = "1.8.0", features = ["v4"] }
2633
num-rational = "*"
2734

2835
# Use unreleased version of egraph-serialize in egglog as well
29-
[patch.crates-io]
30-
egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "5838c036623e91540831745b1574539e01c8cb23" }
36+
# [patch.crates-io]
37+
# egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "5838c036623e91540831745b1574539e01c8cb23" }
3138
# egraph-serialize = { path = "../egraph-serialize" }
3239

3340
# enable debug symbols for easier profiling

docs/reference/contributing.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ conda activate egglog-python
3232
Then install the package in editable mode with the development dependencies:
3333

3434
```bash
35-
pip install -e .[dev]
35+
maturin develop -E .[dev]
3636
```
3737

38-
Anytime you change the rust code, you can run `pip install -e .` to recompile the rust code.
38+
Anytime you change the rust code, you can run `maturin develop -E` to recompile the rust code.
3939

4040
### Running Tests
4141

docs/reference/design.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
## Lambda Functions
2+
3+
Anonymous functions are mapped to uniquely named global functions with a default rewrite of their body.
4+
5+
Constraints:
6+
7+
1. they should be uniquqe based on their types, arguments and body
8+
2. Their name is a version of that
9+
3. Their egg value is their name but simplified
10+
4. Their pretty/serialized version (for graphviz) is that version but without the types.

pyproject.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ docs = [
6262
"ablog",
6363
]
6464

65-
[tool.ruff]
65+
[tool.ruff.lint]
6666
ignore = [
6767
# Allow uppercase vars
6868
"N806",
@@ -74,7 +74,7 @@ ignore = [
7474
# Allow exec
7575
"S102",
7676
"S307",
77-
"PGH001",
77+
"S307",
7878
# allow star imports
7979
"F405",
8080
"F403",
@@ -173,12 +173,16 @@ ignore = [
173173
"D401",
174174
# Allow private member refs
175175
"SLF001",
176+
# allow blind exception to add context
177+
"BLE001",
176178
]
179+
select = ["ALL"]
180+
181+
[tool.ruff]
177182

178183
line-length = 120
179184
# Allow lines to be as long as 120.
180185
src = ["python"]
181-
select = ["ALL"]
182186
extend-exclude = ["python/tests/__snapshots__"]
183187
unsafe-fixes = true
184188

python/egglog/bindings.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class SerializedEGraph:
1414
def to_dot(self) -> str: ...
1515
def to_json(self) -> str: ...
1616
def map_ops(self, map: dict[str, str]) -> None: ...
17+
def split_e_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
1718

1819
@final
1920
class PyObjectSort:
@@ -43,7 +44,6 @@ class EGraph:
4344
max_functions: int | None = None,
4445
max_calls_per_function: int | None = None,
4546
include_temporary_functions: bool = False,
46-
split_primitive_outputs: bool = False,
4747
) -> SerializedEGraph: ...
4848
def eval_py_object(self, __expr: _Expr) -> object: ...
4949
def eval_i64(self, __expr: _Expr) -> int: ...

python/egglog/builtins.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,12 @@ def __getitem__(self, index: i64Like) -> T: ...
408408
@method(egg_fn="rebuild")
409409
def rebuild(self) -> Vec[T]: ...
410410

411+
@method(egg_fn="vec-remove")
412+
def remove(self, index: i64Like) -> Vec[T]: ...
413+
414+
@method(egg_fn="vec-set")
415+
def set(self, index: i64Like, value: T) -> Vec[T]: ...
416+
411417

412418
class PyObject(Expr, builtin=True):
413419
def __init__(self, value: object) -> None: ...

python/egglog/declarations.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
from typing_extensions import Self, assert_never
1414

1515
if TYPE_CHECKING:
16-
from collections.abc import Callable, Iterable
16+
from collections.abc import Callable, Iterable, Mapping
1717

1818

1919
__all__ = [
20+
"replace_typed_expr",
2021
"Declarations",
2122
"DeclerationsLike",
2223
"DelayedDeclerations",
@@ -29,6 +30,7 @@
2930
"MethodRef",
3031
"ClassMethodRef",
3132
"FunctionRef",
33+
"UnnamedFunctionRef",
3234
"ConstantRef",
3335
"ClassVariableRef",
3436
"PropertyRef",
@@ -83,17 +85,14 @@ class DelayedDeclerations:
8385

8486
@property
8587
def __egg_decls__(self) -> Declarations:
88+
thunk = self.__egg_decls_thunk__
8689
try:
87-
return self.__egg_decls_thunk__()
90+
return thunk()
8891
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
8992
# instead raise explicitly
9093
except AttributeError as err:
9194
msg = f"Cannot resolve declerations for {self}"
9295
raise RuntimeError(msg) from err
93-
# Might as well catch others too so we have more context when they raise
94-
except Exception as err: # noqa: BLE001
95-
msg = f"Cannot resolve declerations for {self}"
96-
raise RuntimeError(msg) from err
9796

9897

9998
@runtime_checkable
@@ -121,7 +120,8 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
121120

122121
@dataclass
123122
class Declarations:
124-
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
123+
# TODO: Replace with set of unnamed function decls
124+
_functions: dict[str | UnnamedFunctionRef, FunctionDecl | RelationDecl] = field(default_factory=dict)
125125
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
126126
_classes: dict[str, ClassDecl] = field(default_factory=dict)
127127
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
@@ -323,9 +323,26 @@ def __str__(self) -> str:
323323
##
324324

325325

326+
@dataclass(frozen=True)
327+
class UnnamedFunctionRef:
328+
"""
329+
A reference to a function that doesn't have a name, but does have a body.
330+
"""
331+
332+
arg_types: tuple[JustTypeRef, ...]
333+
arg_names: tuple[str, ...]
334+
res: TypedExprDecl
335+
336+
@property
337+
def args(self) -> tuple[TypedExprDecl, ...]:
338+
return tuple(
339+
TypedExprDecl(tp, VarDecl(name, False)) for tp, name in zip(self.arg_types, self.arg_names, strict=True)
340+
)
341+
342+
326343
@dataclass(frozen=True)
327344
class FunctionRef:
328-
name: str
345+
name: str | UnnamedFunctionRef
329346

330347

331348
@dataclass(frozen=True)
@@ -460,6 +477,8 @@ def to_function_decl(self) -> FunctionDecl:
460477
@dataclass(frozen=True)
461478
class VarDecl:
462479
name: str
480+
# Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
481+
is_let: bool
463482

464483

465484
@dataclass(frozen=True)
@@ -566,6 +585,38 @@ def descendants(self) -> list[TypedExprDecl]:
566585
return l
567586

568587

588+
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
589+
"""
590+
Replace all the typed expressions in the given typed expression with the replacements.
591+
"""
592+
# keep track of the traversed expressions for memoization
593+
traversed: dict[TypedExprDecl, TypedExprDecl] = {}
594+
595+
def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
596+
if typed_expr in traversed:
597+
return traversed[typed_expr]
598+
if typed_expr in replacements:
599+
res = replacements[typed_expr]
600+
else:
601+
match typed_expr.expr:
602+
case (
603+
CallDecl(callable, args, bound_tp_params)
604+
| PartialCallDecl(CallDecl(callable, args, bound_tp_params))
605+
):
606+
new_args = tuple(_inner(a) for a in args)
607+
call_decl = CallDecl(callable, new_args, bound_tp_params)
608+
res = TypedExprDecl(
609+
typed_expr.tp,
610+
call_decl if isinstance(typed_expr.expr, CallDecl) else PartialCallDecl(call_decl),
611+
)
612+
case _:
613+
res = typed_expr
614+
traversed[typed_expr] = res
615+
return res
616+
617+
return _inner(typed_expr)
618+
619+
569620
##
570621
# Schedules
571622
##

0 commit comments

Comments
 (0)