@@ -23,9 +23,48 @@ def ast_call(name, args=None, keywords=None):
2323 )
2424
2525
26+ class FindOperands (ast .NodeVisitor ):
27+ def __init__ (self ):
28+ self .operands = {}
29+ self .results = {}
30+
31+ def visit_Call (self , node : ast .Call ):
32+ if hasattr (node .func , "value" ) and hasattr (node .func .value , "id" ):
33+ if node .func .value .id == "operands" :
34+ if isinstance (node .args [0 ], ast .Call ):
35+ nested_call = node .args [0 ]
36+ is_optional = False
37+ elif isinstance (node .args [0 ], ast .IfExp ):
38+ nested_call = node .args [0 ].body
39+ is_optional = True
40+ else :
41+ raise RuntimeError (
42+ f"unsupported operands python code: { ast .unparse (node )} "
43+ )
44+ oper_name = inflection .underscore (nested_call .args [0 ].id ).lower ()
45+ is_variadic = "values" in nested_call .func .id
46+ type = "list[Value]" if is_variadic else "Value"
47+ if is_optional :
48+ type = f"Optional[{ type } ]"
49+ self .operands [oper_name ] = type
50+ elif node .func .value .id == "results" :
51+ if node .func .attr == "extend" :
52+ if isinstance (node .args [0 ], ast .BinOp ):
53+ # something like results.extend([operands[0].type] * 1)
54+ return
55+ else :
56+ self .results [node .args [0 ].id ] = "list[Type]"
57+ elif node .func .attr == "append" :
58+ self .results [node .args [0 ].id ] = "Type"
59+ else :
60+ raise ValueError ("unknown results object" )
61+
62+
2663# TODO(max): ops that have symboltables need to be classes but that requires some upstream support for statically
2764# identifying such ops
2865def generate_op_trampoline (op_class ):
66+ from mlir_utils .dialects .util import get_result_or_results , maybe_cast , region_op
67+
2968 _mod = ast .parse (dedent (inspect .getsource (op_class .__init__ )))
3069 init_fn = next (n for n in _mod .body if isinstance (n , ast .FunctionDef ))
3170 args = init_fn .args
@@ -41,8 +80,6 @@ def generate_op_trampoline(op_class):
4180 for k , d in zip (args .kwonlyargs , args .kw_defaults )
4281 ]
4382
44- for a in args .args + args .kwonlyargs :
45- a .annotation = None
4683 fun_name = op_class .OPERATION_NAME .split ("." )[- 1 ]
4784 if keyword .iskeyword (fun_name ):
4885 fun_name = fun_name + "_"
@@ -56,18 +93,27 @@ def generate_op_trampoline(op_class):
5693 and op_class ._ODS_REGIONS [0 ] == 1
5794 and not op_class .OPERATION_NAME .startswith ("linalg" )
5895 ):
59- decorator_list = [ast .Name (id = " region_op" , ctx = ast .Load ())]
96+ decorator_list = [ast .Name (id = region_op . __name__ , ctx = ast .Load ())]
6097 body += [ast .Return ([ast_call (op_class_name , args .args , keywords )])]
6198 else :
6299 decorator_list = []
63100 body += [
64101 ast .parse (
65- f"return get_result_or_results({ ast .unparse (ast_call (op_class_name , args .args , keywords ))} )"
102+ f"return { maybe_cast . __name__ } ( { get_result_or_results . __name__ } ({ ast .unparse (ast_call (op_class_name , args .args , keywords ))} ) )"
66103 ).body [0 ]
67104 ]
105+
106+ args = copy .deepcopy (args )
107+ oper_finder = FindOperands ()
108+ oper_finder .visit (init_fn )
109+ for a in args .args :
110+ if a .arg in oper_finder .operands :
111+ a .annotation = ast .Name (id = oper_finder .operands [a .arg ], ctx = ast .Load ())
112+ elif a .arg in oper_finder .results :
113+ a .annotation = ast .Name (id = oper_finder .results [a .arg ], ctx = ast .Load ())
68114 n = ast .FunctionDef (
69115 name = fun_name ,
70- args = copy . deepcopy ( args ) ,
116+ args = args ,
71117 body = body ,
72118 decorator_list = decorator_list ,
73119 )
@@ -77,8 +123,9 @@ def generate_op_trampoline(op_class):
77123
78124def generate_dialect_trampolines_from_module (input_module , skips : set ):
79125 import mlir_utils
80- from mlir_utils .dialects .util import get_result_or_results
126+ from mlir_utils .dialects .util import get_result_or_results , maybe_cast , region_op
81127 import mlir .dialects ._ods_common
128+ from mlir_utils ._configuration .configuration import _get_mlir_package_prefix
82129
83130 skips .update ({"_Dialect" })
84131 init_funs = {}
@@ -92,6 +139,7 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
92139 # these are extension classes and we should wrap the generated class instead
93140 obj = obj .__base__
94141 if not inspect .isfunction (obj .__init__ ):
142+ print (f"skipping { obj .__name__ } because it has no __init__" )
95143 # some builders don't have any __init__ but inherit from opview
96144 continue
97145 init_funs [obj .__name__ ] = obj
@@ -104,9 +152,17 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
104152 for op_class in sorted (init_funs .values (), key = lambda o : o .__name__ )
105153 ]
106154
155+ ir_imports = ast .ImportFrom (
156+ module = _get_mlir_package_prefix () + ".ir" ,
157+ names = [ast .alias (i ) for i in ["Value" , "Attribute" , "Type" ]],
158+ level = 0 ,
159+ )
107160 ods_imports = ast .ImportFrom (
108161 module = mlir_utils .dialects .util .__name__ ,
109- names = [ast .alias (get_result_or_results .__name__ ), ast .alias ("region_op" )],
162+ names = [
163+ ast .alias (f .__name__ )
164+ for f in [get_result_or_results , maybe_cast , region_op ]
165+ ],
110166 level = 0 ,
111167 )
112168 op_imports = ast .ImportFrom (
@@ -125,7 +181,11 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
125181 else :
126182 linalg_imports = []
127183
128- new_mod = ast .Module ([op_imports , * linalg_imports , ods_imports ] + functions , [])
184+ all = ast .parse (f"__all__ = [{ ', ' .join (repr (f .name ) for f in functions )} ]" )
185+
186+ new_mod = ast .Module (
187+ [ir_imports , op_imports , * linalg_imports , ods_imports ] + functions + [all ], []
188+ )
129189 new_src = ast .unparse (new_mod )
130190 return black .format_file_contents (new_src , fast = False , mode = black .Mode ())
131191
0 commit comments