Skip to content

Commit bd4a1f5

Browse files
Add datatype support
1 parent 720c36d commit bd4a1f5

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

python/egg_smol/bindings_py.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# TODO: Figure out what these modules should be called
2+
from __future__ import annotations
3+
4+
from dataclasses import dataclass
5+
from typing import Optional
6+
7+
8+
@dataclass
9+
class Variant:
10+
name: str
11+
types: list[str]
12+
cost: Optional[int] = None

python/tests/test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from egg_smol.bindings import EggSmolError, EGraph
3+
from egg_smol.bindings_py import Variant
34

45

56
class TestEGraph:
@@ -18,3 +19,11 @@ def test_parse_and_run_program_exception(self):
1819
match='Check failed: Value { tag: "i64", bits: 5 } != Value { tag: "i64", bits: 4 }',
1920
):
2021
egraph.parse_and_run_program(program)
22+
23+
def test_datatype(self):
24+
egraph = EGraph()
25+
egraph.declare_sort("Math")
26+
egraph.declare_constructor(Variant("Num", ["i64"]), "Math")
27+
egraph.declare_constructor(Variant("Var", ["String"]), "Math")
28+
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
29+
egraph.declare_constructor(Variant("Mul", ["Math", "Math"]), "Math")

src/lib.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ impl EggSmolError {
2525
}
2626
}
2727

28+
// Convert a Python Variant object into a rust variable, by getting the attributes
29+
fn get_variant(obj: &PyAny) -> PyResult<egg_smol::ast::Variant> {
30+
// TODO: Is there a way to do this more automatically?
31+
Ok(egg_smol::ast::Variant {
32+
name: obj
33+
.getattr(pyo3::intern!(obj.py(), "name"))?
34+
.extract::<String>()?
35+
.into(),
36+
cost: obj.getattr(pyo3::intern!(obj.py(), "cost"))?.extract()?,
37+
types: obj
38+
.getattr(pyo3::intern!(obj.py(), "types"))?
39+
.extract::<Vec<String>>()?
40+
.into_iter()
41+
.map(|x| x.into())
42+
.collect(),
43+
})
44+
}
45+
2846
#[pymethods]
2947
impl EGraph {
3048
#[new]
@@ -34,6 +52,28 @@ impl EGraph {
3452
}
3553
}
3654

55+
/// declare_sort($self, name)
56+
/// --
57+
///
58+
/// Declare a new sort with the given name.
59+
fn declare_sort(&mut self, name: &str) -> PyResult<()> {
60+
// TODO: Should the name be a symbol? If so, how should we expose that
61+
// to Python?
62+
self.egraph
63+
.declare_sort(name)
64+
.map_err(|e| PyErr::new::<EggSmolError, _>(e.to_string()))
65+
}
66+
67+
fn declare_constructor(
68+
&mut self,
69+
#[pyo3(from_py_with = "get_variant")] variant: egg_smol::ast::Variant,
70+
sort: &str,
71+
) -> PyResult<()> {
72+
self.egraph
73+
.declare_constructor(variant, sort)
74+
.map_err(|e| PyErr::new::<EggSmolError, _>(e.to_string()))
75+
}
76+
3777
/// parse_and_run_program($self, input)
3878
/// --
3979
///

0 commit comments

Comments
 (0)