File tree Expand file tree Collapse file tree 2 files changed +33
-24
lines changed Expand file tree Collapse file tree 2 files changed +33
-24
lines changed Original file line number Diff line number Diff line change 1+ // Create wrappers around input types so that convert from pyobjects to them
2+ // and then from them to the egg_smol types
3+ //
4+ // Converts from Python classes we define in pure python so we can use dataclasses
5+ // to represent the input types
6+ use pyo3:: prelude:: * ;
7+
8+ pub struct WrappedVariant ( egg_smol:: ast:: Variant ) ;
9+
10+ impl FromPyObject < ' _ > for WrappedVariant {
11+ fn extract ( obj : & ' _ PyAny ) -> PyResult < Self > {
12+ Ok ( WrappedVariant ( egg_smol:: ast:: Variant {
13+ name : obj. getattr ( "name" ) ?. extract :: < String > ( ) ?. into ( ) ,
14+ cost : obj. getattr ( "cost" ) ?. extract ( ) ?,
15+ types : obj
16+ . getattr ( "types" ) ?
17+ . extract :: < Vec < String > > ( ) ?
18+ . into_iter ( )
19+ . map ( |x| x. into ( ) )
20+ . collect ( ) ,
21+ } ) )
22+ }
23+ }
24+
25+ impl From < WrappedVariant > for egg_smol:: ast:: Variant {
26+ fn from ( other : WrappedVariant ) -> Self {
27+ other. 0
28+ }
29+ }
Original file line number Diff line number Diff line change 1+ mod conversions;
12mod error;
3+ use conversions:: * ;
24use error:: * ;
35use pyo3:: prelude:: * ;
46
@@ -11,24 +13,6 @@ struct EGraph {
1113 egraph : egg_smol:: EGraph ,
1214}
1315
14- // Convert a Python Variant object into a rust variable, by getting the attributes
15- fn get_variant ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Variant > {
16- // TODO: Is there a way to do this more automatically?
17- Ok ( egg_smol:: ast:: Variant {
18- name : obj
19- . getattr ( pyo3:: intern!( obj. py( ) , "name" ) ) ?
20- . extract :: < String > ( ) ?
21- . into ( ) ,
22- cost : obj. getattr ( pyo3:: intern!( obj. py( ) , "cost" ) ) ?. extract ( ) ?,
23- types : obj
24- . getattr ( pyo3:: intern!( obj. py( ) , "types" ) ) ?
25- . extract :: < Vec < String > > ( ) ?
26- . into_iter ( )
27- . map ( |x| x. into ( ) )
28- . collect ( ) ,
29- } )
30- }
31-
3216#[ pymethods]
3317impl EGraph {
3418 #[ new]
@@ -53,12 +37,8 @@ impl EGraph {
5337 /// --
5438 ///
5539 /// Declare a new datatype constructor.
56- fn declare_constructor (
57- & mut self ,
58- #[ pyo3( from_py_with = "get_variant" ) ] variant : egg_smol:: ast:: Variant ,
59- sort : & str ,
60- ) -> EggResult < ( ) > {
61- self . egraph . declare_constructor ( variant, sort) ?;
40+ fn declare_constructor ( & mut self , variant : WrappedVariant , sort : & str ) -> EggResult < ( ) > {
41+ self . egraph . declare_constructor ( variant. into ( ) , sort) ?;
6242 Ok ( { } )
6343 }
6444
You can’t perform that action at this time.
0 commit comments