55// to represent the input types
66use pyo3:: prelude:: * ;
77
8+ // Execute the block and wrap the error in a type error
9+ fn wrap_error < T > ( tp : & str , obj : & ' _ PyAny , block : impl FnOnce ( ) -> PyResult < T > ) -> PyResult < T > {
10+ block ( ) . map_err ( |e| {
11+ PyErr :: new :: < pyo3:: exceptions:: PyTypeError , _ > ( format ! (
12+ "Error converting {} to {}: {}" ,
13+ obj, tp, e
14+ ) )
15+ } )
16+ }
17+
18+ // Wrapped version of Variant
819pub struct WrappedVariant ( egg_smol:: ast:: Variant ) ;
920
1021impl FromPyObject < ' _ > for WrappedVariant {
1122 fn extract ( obj : & ' _ PyAny ) -> PyResult < Self > {
12- Ok ( WrappedVariant ( egg_smol:: ast:: Variant {
13- name : obj. getattr ( "name" ) ?. extract :: < String > ( ) ?. into ( ) ,
14- cost : obj. getattr ( "cost" ) ?. extract ( ) ?,
15- types : obj
16- . getattr ( "types" ) ?
17- . extract :: < Vec < String > > ( ) ?
18- . into_iter ( )
19- . map ( |x| x. into ( ) )
20- . collect ( ) ,
21- } ) )
23+ wrap_error ( "Variant" , obj, || {
24+ Ok ( WrappedVariant ( egg_smol:: ast:: Variant {
25+ name : obj. getattr ( "name" ) ?. extract :: < String > ( ) ?. into ( ) ,
26+ cost : obj. getattr ( "cost" ) ?. extract ( ) ?,
27+ types : obj
28+ . getattr ( "types" ) ?
29+ . extract :: < Vec < String > > ( ) ?
30+ . into_iter ( )
31+ . map ( |x| x. into ( ) )
32+ . collect ( ) ,
33+ } ) )
34+ } )
2235 }
2336}
2437
@@ -27,3 +40,138 @@ impl From<WrappedVariant> for egg_smol::ast::Variant {
2740 other. 0
2841 }
2942}
43+
44+ // Wrapped version of FunctionDecl
45+ pub struct WrappedFunctionDecl ( egg_smol:: ast:: FunctionDecl ) ;
46+ impl FromPyObject < ' _ > for WrappedFunctionDecl {
47+ fn extract ( obj : & ' _ PyAny ) -> PyResult < Self > {
48+ wrap_error ( "FunctionDecl" , obj, || {
49+ Ok ( WrappedFunctionDecl ( egg_smol:: ast:: FunctionDecl {
50+ name : obj. getattr ( "name" ) ?. extract :: < String > ( ) ?. into ( ) ,
51+ schema : obj. getattr ( "schema" ) ?. extract :: < WrappedSchema > ( ) ?. into ( ) ,
52+ default : obj
53+ . getattr ( "default" ) ?
54+ . extract :: < Option < WrappedExpr > > ( ) ?
55+ . map ( |x| x. into ( ) ) ,
56+ merge : obj
57+ . getattr ( "merge" ) ?
58+ . extract :: < Option < WrappedExpr > > ( ) ?
59+ . map ( |x| x. into ( ) ) ,
60+ cost : obj. getattr ( "cost" ) ?. extract ( ) ?,
61+ } ) )
62+ } )
63+ }
64+ }
65+
66+ impl From < WrappedFunctionDecl > for egg_smol:: ast:: FunctionDecl {
67+ fn from ( other : WrappedFunctionDecl ) -> Self {
68+ other. 0
69+ }
70+ }
71+
72+ // Wrapped version of Schema
73+ pub struct WrappedSchema ( egg_smol:: ast:: Schema ) ;
74+
75+ impl FromPyObject < ' _ > for WrappedSchema {
76+ fn extract ( obj : & ' _ PyAny ) -> PyResult < Self > {
77+ wrap_error ( "Schema" , obj, || {
78+ Ok ( WrappedSchema ( egg_smol:: ast:: Schema {
79+ input : obj
80+ . getattr ( "input" ) ?
81+ . extract :: < Vec < String > > ( ) ?
82+ . into_iter ( )
83+ . map ( |x| x. into ( ) )
84+ . collect ( ) ,
85+ output : obj. getattr ( "output" ) ?. extract :: < String > ( ) ?. into ( ) ,
86+ } ) )
87+ } )
88+ }
89+ }
90+
91+ impl From < WrappedSchema > for egg_smol:: ast:: Schema {
92+ fn from ( other : WrappedSchema ) -> Self {
93+ other. 0
94+ }
95+ }
96+
97+ // Wrapped version of Expr
98+ pub struct WrappedExpr ( egg_smol:: ast:: Expr ) ;
99+
100+ impl FromPyObject < ' _ > for WrappedExpr {
101+ fn extract ( obj : & ' _ PyAny ) -> PyResult < Self > {
102+ wrap_error ( "Expr" , obj, ||
103+ // Try extracting into each type of expression, and return the first one that works
104+ extract_expr_lit ( obj)
105+ . or_else ( |_| extract_expr_call ( obj) )
106+ . or_else ( |_| extract_expr_var ( obj) )
107+ . map ( WrappedExpr ) )
108+ }
109+ }
110+
111+ fn extract_expr_lit ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Expr > {
112+ Ok ( egg_smol:: ast:: Expr :: Lit (
113+ obj. getattr ( "value" ) ?. extract :: < WrappedLiteral > ( ) ?. into ( ) ,
114+ ) )
115+ }
116+
117+ fn extract_expr_var ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Expr > {
118+ Ok ( egg_smol:: ast:: Expr :: Var (
119+ obj. getattr ( "name" ) ?. extract :: < String > ( ) ?. into ( ) ,
120+ ) )
121+ }
122+
123+ fn extract_expr_call ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Expr > {
124+ Ok ( egg_smol:: ast:: Expr :: Call (
125+ obj. getattr ( "name" ) ?. extract :: < String > ( ) ?. into ( ) ,
126+ obj. getattr ( "args" ) ?
127+ . extract :: < Vec < WrappedExpr > > ( ) ?
128+ . into_iter ( )
129+ . map ( |x| x. into ( ) )
130+ . collect ( ) ,
131+ ) )
132+ }
133+
134+ impl From < WrappedExpr > for egg_smol:: ast:: Expr {
135+ fn from ( other : WrappedExpr ) -> Self {
136+ other. 0
137+ }
138+ }
139+
140+ // Wrapped version of Literal
141+ pub struct WrappedLiteral ( egg_smol:: ast:: Literal ) ;
142+
143+ impl FromPyObject < ' _ > for WrappedLiteral {
144+ fn extract ( obj : & ' _ PyAny ) -> PyResult < Self > {
145+ wrap_error ( "Literal" , obj, || {
146+ extract_literal_int ( obj)
147+ . or_else ( |_| extract_literal_string ( obj) )
148+ . or_else ( |_| extract_literal_unit ( obj) )
149+ . map ( WrappedLiteral )
150+ } )
151+ }
152+ }
153+
154+ fn extract_literal_int ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Literal > {
155+ Ok ( egg_smol:: ast:: Literal :: Int (
156+ obj. getattr ( "value" ) ?. extract ( ) ?,
157+ ) )
158+ }
159+
160+ fn extract_literal_string ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Literal > {
161+ Ok ( egg_smol:: ast:: Literal :: String (
162+ obj. getattr ( "value" ) ?. extract :: < String > ( ) ?. into ( ) ,
163+ ) )
164+ }
165+ fn extract_literal_unit ( obj : & PyAny ) -> PyResult < egg_smol:: ast:: Literal > {
166+ if obj. is_none ( ) {
167+ Ok ( egg_smol:: ast:: Literal :: Unit )
168+ } else {
169+ Err ( pyo3:: exceptions:: PyTypeError :: new_err ( "Expected None" ) )
170+ }
171+ }
172+
173+ impl From < WrappedLiteral > for egg_smol:: ast:: Literal {
174+ fn from ( other : WrappedLiteral ) -> Self {
175+ other. 0
176+ }
177+ }
0 commit comments