Skip to content

Commit 4383bcc

Browse files
Expose define command
1 parent 17ade5d commit 4383bcc

File tree

10 files changed

+295
-22
lines changed

10 files changed

+295
-22
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8
22
[flake8]
33
max-line-length = 88
4-
extend-ignore = E203,E501
4+
extend-ignore = E203,E501,F405,F403

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@ crate-type = ["cdylib"]
1010

1111
[dependencies]
1212
pyo3 = { version = "0.17.1", features = ["extension-module"] }
13-
egg-smol = { git = "https://github.com/mwillsey/egg-smol", rev = "94c173b9d6152c18deca307bb2e8d2b8c412c16b" }
13+
egg-smol = { git = "https://github.com/saulshanabrook/egg-smol", branch = "public-api" }
1414

1515
[package.metadata.maturin]
1616
name = "egg_smol.bindings"
17-
18-
# Patch egg-smol to use local version
19-
# [patch.crates-io]
20-
# egg-smol = { path = "../egg-smol" }

python/egg_smol/bindings.pyi

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional
2+
3+
from typing_extensions import final
4+
5+
from .bindings_py import Expr, FunctionDecl, Variant
6+
7+
@final
8+
class EGraph:
9+
def ___init__(self) -> None: ...
10+
def parse_and_run_program(self, program: str) -> list[str]: ...
11+
def declare_constructor(self, variant: Variant, sort: str) -> None: ...
12+
def declare_sort(self, name: str) -> None: ...
13+
def declare_function(self, function: FunctionDecl) -> None: ...
14+
def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ...
15+
16+
@final
17+
class EggSmolError(Exception):
18+
pass

python/egg_smol/bindings_py.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,63 @@
22
from __future__ import annotations
33

44
from dataclasses import dataclass
5-
from typing import Optional
5+
from typing import Optional, Union
66

77

8-
@dataclass
8+
@dataclass(frozen=True)
99
class Variant:
1010
name: str
1111
types: list[str]
1212
cost: Optional[int] = None
13+
14+
15+
@dataclass(frozen=True)
16+
class FunctionDecl:
17+
name: str
18+
schema: Schema
19+
default: Optional[Expr] = None
20+
merge: Optional[Expr] = None
21+
cost: Optional[int] = None
22+
23+
24+
@dataclass(frozen=True)
25+
class Schema:
26+
input: list[str]
27+
output: str
28+
29+
30+
@dataclass(frozen=True)
31+
class Lit:
32+
value: Literal
33+
34+
35+
@dataclass(frozen=True)
36+
class Var:
37+
name: str
38+
39+
40+
@dataclass(frozen=True)
41+
class Call:
42+
name: str
43+
args: list[Expr]
44+
45+
46+
Expr = Union[Lit, Var, Call]
47+
48+
49+
@dataclass(frozen=True)
50+
class Int:
51+
value: int
52+
53+
54+
@dataclass(frozen=True)
55+
class String:
56+
value: str
57+
58+
59+
@dataclass(frozen=True)
60+
class Unit:
61+
pass
62+
63+
64+
Literal = Union[Int, String, Unit]

python/egg_smol/py.typed

Whitespace-only changes.

python/tests/test.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from egg_smol.bindings import EggSmolError, EGraph
3-
from egg_smol.bindings_py import Variant
3+
from egg_smol.bindings_py import *
44

55

66
class TestEGraph:
@@ -27,3 +27,44 @@ def test_datatype(self):
2727
egraph.declare_constructor(Variant("Var", ["String"]), "Math")
2828
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
2929
egraph.declare_constructor(Variant("Mul", ["Math", "Math"]), "Math")
30+
31+
def test_define(self):
32+
egraph = EGraph()
33+
egraph.declare_sort("Math")
34+
egraph.declare_constructor(Variant("Num", ["i64"]), "Math")
35+
egraph.declare_constructor(Variant("Var", ["String"]), "Math")
36+
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
37+
egraph.declare_constructor(Variant("Mul", ["Math", "Math"]), "Math")
38+
39+
# (define expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))
40+
egraph.define(
41+
"expr1",
42+
Call(
43+
"Mul",
44+
[
45+
Call(
46+
"Num",
47+
[
48+
Lit(Int(2)),
49+
],
50+
),
51+
Call(
52+
"Add",
53+
[
54+
Call(
55+
"Var",
56+
[
57+
Lit(String("x")),
58+
],
59+
),
60+
Call(
61+
"Num",
62+
[
63+
Lit(Int(3)),
64+
],
65+
),
66+
],
67+
),
68+
],
69+
),
70+
)

src/conversions.rs

Lines changed: 158 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,33 @@
55
// to represent the input types
66
use pyo3::prelude::*;
77

8+
// Execute the block and wrap the error in a type error
9+
fn wrap_error<T>(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult<T>) -> PyResult<T> {
10+
block().map_err(|e| {
11+
PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
12+
"Error converting {} to {}: {}",
13+
obj, tp, e
14+
))
15+
})
16+
}
17+
18+
// Wrapped version of Variant
819
pub struct WrappedVariant(egg_smol::ast::Variant);
920

1021
impl FromPyObject<'_> for WrappedVariant {
1122
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
12-
Ok(WrappedVariant(egg_smol::ast::Variant {
13-
name: obj.getattr("name")?.extract::<String>()?.into(),
14-
cost: obj.getattr("cost")?.extract()?,
15-
types: obj
16-
.getattr("types")?
17-
.extract::<Vec<String>>()?
18-
.into_iter()
19-
.map(|x| x.into())
20-
.collect(),
21-
}))
23+
wrap_error("Variant", obj, || {
24+
Ok(WrappedVariant(egg_smol::ast::Variant {
25+
name: obj.getattr("name")?.extract::<String>()?.into(),
26+
cost: obj.getattr("cost")?.extract()?,
27+
types: obj
28+
.getattr("types")?
29+
.extract::<Vec<String>>()?
30+
.into_iter()
31+
.map(|x| x.into())
32+
.collect(),
33+
}))
34+
})
2235
}
2336
}
2437

@@ -27,3 +40,138 @@ impl From<WrappedVariant> for egg_smol::ast::Variant {
2740
other.0
2841
}
2942
}
43+
44+
// Wrapped version of FunctionDecl
45+
pub struct WrappedFunctionDecl(egg_smol::ast::FunctionDecl);
46+
impl FromPyObject<'_> for WrappedFunctionDecl {
47+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
48+
wrap_error("FunctionDecl", obj, || {
49+
Ok(WrappedFunctionDecl(egg_smol::ast::FunctionDecl {
50+
name: obj.getattr("name")?.extract::<String>()?.into(),
51+
schema: obj.getattr("schema")?.extract::<WrappedSchema>()?.into(),
52+
default: obj
53+
.getattr("default")?
54+
.extract::<Option<WrappedExpr>>()?
55+
.map(|x| x.into()),
56+
merge: obj
57+
.getattr("merge")?
58+
.extract::<Option<WrappedExpr>>()?
59+
.map(|x| x.into()),
60+
cost: obj.getattr("cost")?.extract()?,
61+
}))
62+
})
63+
}
64+
}
65+
66+
impl From<WrappedFunctionDecl> for egg_smol::ast::FunctionDecl {
67+
fn from(other: WrappedFunctionDecl) -> Self {
68+
other.0
69+
}
70+
}
71+
72+
// Wrapped version of Schema
73+
pub struct WrappedSchema(egg_smol::ast::Schema);
74+
75+
impl FromPyObject<'_> for WrappedSchema {
76+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
77+
wrap_error("Schema", obj, || {
78+
Ok(WrappedSchema(egg_smol::ast::Schema {
79+
input: obj
80+
.getattr("input")?
81+
.extract::<Vec<String>>()?
82+
.into_iter()
83+
.map(|x| x.into())
84+
.collect(),
85+
output: obj.getattr("output")?.extract::<String>()?.into(),
86+
}))
87+
})
88+
}
89+
}
90+
91+
impl From<WrappedSchema> for egg_smol::ast::Schema {
92+
fn from(other: WrappedSchema) -> Self {
93+
other.0
94+
}
95+
}
96+
97+
// Wrapped version of Expr
98+
pub struct WrappedExpr(egg_smol::ast::Expr);
99+
100+
impl FromPyObject<'_> for WrappedExpr {
101+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
102+
wrap_error("Expr", obj, ||
103+
// Try extracting into each type of expression, and return the first one that works
104+
extract_expr_lit(obj)
105+
.or_else(|_| extract_expr_call(obj))
106+
.or_else(|_| extract_expr_var(obj))
107+
.map(WrappedExpr))
108+
}
109+
}
110+
111+
fn extract_expr_lit(obj: &PyAny) -> PyResult<egg_smol::ast::Expr> {
112+
Ok(egg_smol::ast::Expr::Lit(
113+
obj.getattr("value")?.extract::<WrappedLiteral>()?.into(),
114+
))
115+
}
116+
117+
fn extract_expr_var(obj: &PyAny) -> PyResult<egg_smol::ast::Expr> {
118+
Ok(egg_smol::ast::Expr::Var(
119+
obj.getattr("name")?.extract::<String>()?.into(),
120+
))
121+
}
122+
123+
fn extract_expr_call(obj: &PyAny) -> PyResult<egg_smol::ast::Expr> {
124+
Ok(egg_smol::ast::Expr::Call(
125+
obj.getattr("name")?.extract::<String>()?.into(),
126+
obj.getattr("args")?
127+
.extract::<Vec<WrappedExpr>>()?
128+
.into_iter()
129+
.map(|x| x.into())
130+
.collect(),
131+
))
132+
}
133+
134+
impl From<WrappedExpr> for egg_smol::ast::Expr {
135+
fn from(other: WrappedExpr) -> Self {
136+
other.0
137+
}
138+
}
139+
140+
// Wrapped version of Literal
141+
pub struct WrappedLiteral(egg_smol::ast::Literal);
142+
143+
impl FromPyObject<'_> for WrappedLiteral {
144+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
145+
wrap_error("Literal", obj, || {
146+
extract_literal_int(obj)
147+
.or_else(|_| extract_literal_string(obj))
148+
.or_else(|_| extract_literal_unit(obj))
149+
.map(WrappedLiteral)
150+
})
151+
}
152+
}
153+
154+
fn extract_literal_int(obj: &PyAny) -> PyResult<egg_smol::ast::Literal> {
155+
Ok(egg_smol::ast::Literal::Int(
156+
obj.getattr("value")?.extract()?,
157+
))
158+
}
159+
160+
fn extract_literal_string(obj: &PyAny) -> PyResult<egg_smol::ast::Literal> {
161+
Ok(egg_smol::ast::Literal::String(
162+
obj.getattr("value")?.extract::<String>()?.into(),
163+
))
164+
}
165+
fn extract_literal_unit(obj: &PyAny) -> PyResult<egg_smol::ast::Literal> {
166+
if obj.is_none() {
167+
Ok(egg_smol::ast::Literal::Unit)
168+
} else {
169+
Err(pyo3::exceptions::PyTypeError::new_err("Expected None"))
170+
}
171+
}
172+
173+
impl From<WrappedLiteral> for egg_smol::ast::Literal {
174+
fn from(other: WrappedLiteral) -> Self {
175+
other.0
176+
}
177+
}

src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ impl EggSmolError {
1919
// Wrap the egg_smol::Error so we can automatically convert from it to the PyErr
2020
// and so return it from each function automatically
2121
// https://pyo3.rs/latest/function/error_handling.html#foreign-rust-error-types
22+
// TODO: Create classes for each of these errors
2223
pub struct WrappedError(egg_smol::Error);
2324

2425
// Convert from the WrappedError to the PyErr by creating a new Python error

src/lib.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,30 @@ impl EGraph {
2222
}
2323
}
2424

25+
/// define($self, expr, name, cost)
26+
/// --
27+
///
28+
/// Define a new named value.
29+
#[args(cost = "None")]
30+
fn define(&mut self, name: String, expr: WrappedExpr, cost: Option<usize>) -> EggResult<()> {
31+
self.egraph.define(name.into(), expr.into(), cost)?;
32+
Ok(())
33+
}
34+
35+
/// declare_function($self, decl)
36+
/// --
37+
///
38+
/// Declare a new function definition.
39+
fn declare_function(&mut self, decl: WrappedFunctionDecl) -> EggResult<()> {
40+
self.egraph.declare_function(&decl.into())?;
41+
Ok(())
42+
}
43+
2544
/// declare_sort($self, name)
2645
/// --
2746
///
2847
/// Declare a new sort with the given name.
2948
fn declare_sort(&mut self, name: &str) -> EggResult<()> {
30-
// TODO: Should the name be a symbol? If so, how should we expose that
31-
// to Python?
3249
self.egraph.declare_sort(name)?;
3350
Ok({})
3451
}

0 commit comments

Comments
 (0)