@@ -112,11 +112,38 @@ def visit(self, node: ast.AST) -> Result:
112112 # it will be called first before __dispatch_Call
113113 # because "Call" exists in self.registry
114114 return self .__dispatch_Call (node )
115+ elif isinstance (node , ast .With ):
116+ return self .__dispatch_With (node )
115117 return super ().visit (node )
116118
117119 def generic_visit (self , node : ast .AST ):
118120 raise DialectLoweringError (f"unsupported ast node { type (node )} :" )
119121
122+ def __dispatch_With (self , node : ast .With ):
123+ if len (node .items ) != 1 :
124+ raise DialectLoweringError ("expected exactly one item in with statement" )
125+
126+ item = node .items [0 ]
127+ if not isinstance (item .context_expr , ast .Call ):
128+ raise DialectLoweringError ("expected context expression to be a call" )
129+
130+ global_callee_result = self .get_global_nothrow (item .context_expr .func )
131+ if global_callee_result is None :
132+ raise DialectLoweringError ("cannot find call func in with context" )
133+
134+ global_callee = global_callee_result .unwrap ()
135+ if not issubclass (global_callee , Statement ):
136+ raise DialectLoweringError ("expected callee to be a statement" )
137+
138+ if (
139+ trait := global_callee .get_trait (traits .FromPythonWithSingleItem )
140+ ) is not None :
141+ return trait .lower (global_callee , self , node )
142+
143+ raise DialectLoweringError (
144+ "unsupported callee, missing FromPythonWithSingleItem trait"
145+ )
146+
120147 def __dispatch_Call (self , node : ast .Call ):
121148 # 1. try to lookup global statement object
122149 # 2. lookup local values
@@ -196,6 +223,63 @@ def __lower_Call_local(self, node: ast.Call) -> Result:
196223 return self .registry ["Call_local" ].lower_Call_local (self , callee , node )
197224 raise DialectLoweringError ("`lower_Call_local` not implemented" )
198225
226+ def default_Call_lower (self , stmt : type [Statement ], node : ast .Call ) -> Result :
227+ """Default lowering for Python call to statement.
228+
229+ This method is intended to be used by traits like `FromPythonCall` to
230+ provide a default lowering for Python calls to statements.
231+
232+ Args:
233+ stmt(type[Statement]): Statement class to construct.
234+ node(ast.Call): Python call node to lower.
235+
236+ Returns:
237+ Result: Result of lowering the Python call to statement.
238+ """
239+ args , kwargs = self .default_Call_inputs (stmt , node )
240+ return Result (self .append_stmt (stmt (* args .values (), ** kwargs )))
241+
242+ def default_Call_inputs (
243+ self , stmt : type [Statement ], node : ast .Call
244+ ) -> tuple [dict [str , SSAValue | tuple [SSAValue , ...]], dict [str , Any ]]:
245+ from kirin .decl import fields
246+ from kirin .dialects .py .data import PyAttr
247+
248+ fs = fields (stmt )
249+ stmt_std_arg_names = fs .std_args .keys ()
250+ stmt_kw_args_name = fs .kw_args .keys ()
251+ stmt_attr_prop_names = fs .attr_or_props
252+ stmt_required_names = fs .required_names
253+ stmt_group_arg_names = fs .group_arg_names
254+ args , kwargs = {}, {}
255+ for name , value in zip (stmt_std_arg_names , node .args ):
256+ self ._parse_arg (stmt_group_arg_names , args , name , value )
257+ for kw in node .keywords :
258+ if not isinstance (kw .arg , str ):
259+ raise DialectLoweringError ("Expected string for keyword argument name" )
260+
261+ arg : str = kw .arg
262+ if arg in node .args :
263+ raise DialectLoweringError (
264+ f"Keyword argument { arg } is already present in positional arguments"
265+ )
266+ elif arg in stmt_std_arg_names or arg in stmt_kw_args_name :
267+ self ._parse_arg (stmt_group_arg_names , kwargs , kw .arg , kw .value )
268+ elif arg in stmt_attr_prop_names :
269+ if not isinstance (kw .value , ast .Constant ):
270+ raise DialectLoweringError (
271+ f"Expected constant for attribute or property { arg } "
272+ )
273+ kwargs [arg ] = PyAttr (kw .value .value )
274+ else :
275+ raise DialectLoweringError (f"Unexpected keyword argument { arg } " )
276+
277+ for name in stmt_required_names :
278+ if name not in args and name not in kwargs :
279+ raise DialectLoweringError (f"Missing required argument { name } " )
280+
281+ return args , kwargs
282+
199283 def _parse_arg (
200284 self ,
201285 group_names : set [str ],
0 commit comments