Skip to content

Commit d323f33

Browse files
Merge pull request #1 from metadsl/move-rust
Move types into rust
2 parents 71093ea + eac7f1d commit d323f33

File tree

13 files changed

+540
-358
lines changed

13 files changed

+540
-358
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8
22
[flake8]
33
max-line-length = 88
4-
extend-ignore = E203,E501,F405,F403,E302
4+
extend-ignore = E203,E501,F405,F403,E302,E305,F821

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
python-version: "3.10"
3939
- uses: actions/checkout@v2
4040
- run: pip install -e . mypy
41-
- run: python -m mypy.stubtest egg_smol.bindings
41+
- run: python -m mypy.stubtest egg_smol.bindings --allowlist stubtest_allow
4242
docs:
4343
runs-on: ubuntu-latest
4444
steps:

.pre-commit-config.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,3 @@ repos:
1818
rev: 5.0.4
1919
hooks:
2020
- id: flake8
21-
- repo: https://github.com/pre-commit/mirrors-mypy
22-
rev: "v0.982"
23-
hooks:
24-
- id: mypy

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ crate-type = ["cdylib"]
1010

1111
[dependencies]
1212
pyo3 = { version = "0.17.1", features = ["extension-module"] }
13-
egg-smol = { git = "https://github.com/saulshanabrook/egg-smol", branch = "public-api" }
13+
egg-smol = { git = "https://github.com/mwillsey/egg-smol", ref = "9a45bdee9b5395ab821318a033c6ac7fa91f91b9" }
1414

1515
[package.metadata.maturin]
1616
name = "egg_smol.bindings"

docs/explanation/compared_to_rust.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ One way to run this in Python is to parse the text and run it similar to how the
5252
egg CLI works:
5353

5454
```{code-cell} python
55+
from egg_smol.bindings import *
56+
5557
eqsat_basic = """(datatype Math
5658
(Num i64)
5759
(Var String)
@@ -78,8 +80,6 @@ eqsat_basic = """(datatype Math
7880
(run 10)
7981
(check (= expr1 expr2))"""
8082
81-
from egg_smol.bindings import EGraph
82-
8383
egraph = EGraph()
8484
egraph.parse_and_run_program(eqsat_basic)
8585
```
@@ -90,8 +90,6 @@ However, this isn't the most friendly for Python users. Instead, we can use the
9090
low level APIs that mirror the rust APIs to build the same egraph:
9191

9292
```{code-cell} python
93-
from egg_smol.bindings_py import *
94-
9593
egraph = EGraph()
9694
egraph.declare_sort("Math")
9795
egraph.declare_constructor(Variant("Num", ["i64"]), "Math")

python/egg_smol/bindings.pyi

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,147 @@ from typing import Optional
33

44
from typing_extensions import final
55

6-
from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite, Variant
7-
86
@final
97
class EGraph:
108
def parse_and_run_program(self, input: str) -> list[str]: ...
119
def declare_constructor(self, variant: Variant, sort: str) -> None: ...
1210
def declare_sort(self, name: str) -> None: ...
1311
def declare_function(self, decl: FunctionDecl) -> None: ...
14-
def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ...
12+
def define(self, name: str, expr: _Expr, cost: Optional[int] = None) -> None: ...
1513
def add_rewrite(self, rewrite: Rewrite) -> str: ...
1614
def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ...
17-
def check_fact(self, fact: Fact_) -> None: ...
15+
def check_fact(self, fact: _Fact) -> None: ...
1816

1917
@final
2018
class EggSmolError(Exception):
2119
context: str
20+
21+
@final
22+
class Int:
23+
def __init__(self, value: int) -> None: ...
24+
value: int
25+
26+
@final
27+
class String:
28+
def __init__(self, value: str) -> None: ...
29+
value: str
30+
31+
@final
32+
class Unit:
33+
def __init__(self) -> None: ...
34+
35+
_Literal = Int | String | Unit
36+
37+
@final
38+
class Lit:
39+
def __init__(self, value: _Literal) -> None: ...
40+
value: _Literal
41+
42+
@final
43+
class Var:
44+
def __init__(self, name: str) -> None: ...
45+
name: str
46+
47+
@final
48+
class Call:
49+
def __init__(self, name: str, args: list[_Expr]) -> None: ...
50+
name: str
51+
args: list[_Expr]
52+
53+
_Expr = Lit | Var | Call
54+
55+
@final
56+
class Eq:
57+
def __init__(self, exprs: list[_Expr]) -> None: ...
58+
exprs: list[_Expr]
59+
60+
@final
61+
class Fact:
62+
def __init__(self, expr: _Expr) -> None: ...
63+
expr: _Expr
64+
65+
_Fact = Fact | Eq
66+
67+
@final
68+
class Define:
69+
def __init__(self, lhs: str, rhs: _Expr) -> None: ...
70+
lhs: str
71+
rhs: _Expr
72+
73+
@final
74+
class Set:
75+
def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
76+
lhs: str
77+
args: list[_Expr]
78+
rhs: _Expr
79+
80+
@final
81+
class Delete:
82+
sym: str
83+
args: list[_Expr]
84+
def __init__(self, sym: str, args: list[_Expr]) -> None: ...
85+
86+
@final
87+
class Union:
88+
def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ...
89+
lhs: _Expr
90+
rhs: _Expr
91+
92+
@final
93+
class Panic:
94+
def __init__(self, msg: str) -> None: ...
95+
msg: str
96+
97+
@final
98+
class Expr_:
99+
def __init__(self, expr: _Expr) -> None: ...
100+
expr: _Expr
101+
102+
_Action = Define | Set | Delete | Union | Panic | Expr_
103+
104+
@final
105+
class FunctionDecl:
106+
name: str
107+
schema: Schema
108+
default: Optional[_Expr]
109+
merge: Optional[_Expr]
110+
cost: Optional[int]
111+
def __init__(
112+
self,
113+
name: str,
114+
schema: Schema,
115+
default: Optional[_Expr],
116+
merge: Optional[_Expr],
117+
cost: Optional[int] = None,
118+
) -> None: ...
119+
120+
@final
121+
class Variant:
122+
def __init__(
123+
self, name: str, types: list[str], cost: Optional[int] = None
124+
) -> None: ...
125+
name: str
126+
types: list[str]
127+
cost: Optional[int]
128+
129+
@final
130+
class Schema:
131+
input: list[str]
132+
output: str
133+
def __init__(self, input: list[str], output: str) -> None: ...
134+
135+
@final
136+
class Rule:
137+
head: list[_Action]
138+
body: list[_Fact]
139+
def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ...
140+
141+
@final
142+
class Rewrite:
143+
lhs: _Expr
144+
rhs: _Expr
145+
conditions: list[_Fact]
146+
147+
def __init__(
148+
self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []
149+
) -> None: ...

python/egg_smol/bindings_py.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

python/tests/test.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import datetime
22

33
import pytest
4-
from egg_smol.bindings import EggSmolError, EGraph
5-
from egg_smol.bindings_py import *
4+
from egg_smol.bindings import *
65

76

87
class TestEGraph:
@@ -334,3 +333,29 @@ def test_check_fact(self):
334333
]
335334
)
336335
)
336+
337+
# def test_extract(self):
338+
# # Example from extraction-cost
339+
# egraph = EGraph()
340+
# egraph.declare_sort("Expr")
341+
# egraph.declare_constructor(Variant("Num", ["i64"], cost=5), "Expr")
342+
343+
# egraph.define("x", Call("Num", [Lit(Int(1))]), cost=10)
344+
# egraph.define("y", Call("Num", [Lit(Int(2))]), cost=1)
345+
346+
# assert egraph.extract("x") == Call("Num", [Lit(Int(1))])
347+
# assert egraph.extract("y") == Var("y")
348+
349+
350+
class TestVariant:
351+
def test_repr(self):
352+
assert repr(Variant("name", [])) == "Variant('name', [], None)"
353+
354+
def test_name(self):
355+
assert Variant("name", []).name == "name"
356+
357+
def test_types(self):
358+
assert Variant("name", ["a", "b"]).types == ["a", "b"]
359+
360+
def test_cost(self):
361+
assert Variant("name", [], cost=1).cost == 1

0 commit comments

Comments
 (0)