@@ -252,14 +252,92 @@ def generate_trampolines(
252252 _add_file_to_sources_txt_file (dst_path )
253253
254254
255+ def generate_linalg (mod_path ):
256+ import mlir_utils .dialects
257+ import mlir .dialects
258+ from mlir_utils ._configuration .configuration import _add_file_to_sources_txt_file
259+ from mlir .dialects .linalg import DefinedOpCallable , OperandKind
260+ from mlir_utils .util import (
261+ get_result_or_results ,
262+ maybe_cast ,
263+ region_op ,
264+ get_user_code_loc ,
265+ )
266+
267+ linalg_modu = __import__ (mod_path , fromlist = ["*" ])
268+
269+ dst_path = Path (mlir_utils .dialects .__path__ [0 ])
270+ dst_path = dst_path / f"linalg.py"
271+ with open (dst_path , "w" ) as f :
272+ functions = []
273+ linalg_import = ast .ImportFrom (
274+ module = mlir .dialects .__name__ ,
275+ names = [ast .Name ("linalg" )],
276+ level = 0 ,
277+ )
278+ ods_imports = ast .ImportFrom (
279+ module = mlir_utils .util .__name__ ,
280+ names = [ast .alias (f .__name__ ) for f in [maybe_cast , get_user_code_loc ]],
281+ level = 0 ,
282+ )
283+ _keywords = [
284+ ast .keyword ("loc" , ast .Name ("loc" )),
285+ ast .keyword ("ip" , ast .Name ("ip" )),
286+ ]
287+ for name , op_callable in inspect .getmembers (
288+ linalg_modu , lambda x : isinstance (x , DefinedOpCallable )
289+ ):
290+ inputs = [
291+ ast .Name (o )
292+ for o , def_ in op_callable .op_def .registered_operands .items ()
293+ if def_ .kind == OperandKind .INPUT_TENSOR
294+ ]
295+ outputs = [
296+ ast .Name (o )
297+ for o , def_ in op_callable .op_def .registered_operands .items ()
298+ if def_ .kind == OperandKind .OUTPUT_TENSOR
299+ ]
300+
301+ keywords = _keywords + [ast .keyword ("outs" , ast .List (outputs ))]
302+ # body = [ast.Str(op_callable.op_def.metadata.doc)]
303+ body = [ast .parse (f"if loc is None: loc = { get_user_code_loc .__name__ } ()" )]
304+ body += [
305+ ast .parse (
306+ f"return { maybe_cast .__name__ } ({ ast .unparse (ast_call ('linalg.' + name , inputs , keywords ))} )"
307+ ).body [0 ]
308+ ]
309+ n = ast .FunctionDef (
310+ name = op_callable .op_name ,
311+ args = ast .arguments (
312+ posonlyargs = [],
313+ args = inputs + outputs ,
314+ defaults = [],
315+ kwonlyargs = [ast .Name ("loc" ), ast .Name ("ip" )],
316+ kw_defaults = [ast .Name ("None" ), ast .Name ("None" )],
317+ ),
318+ body = body ,
319+ decorator_list = [],
320+ )
321+ ast .fix_missing_locations (n )
322+ functions .append (n )
323+
324+ new_mod = ast .Module ([linalg_import , ods_imports ] + functions , [])
325+ new_src = ast .unparse (new_mod )
326+ generated = black .format_file_contents (new_src , fast = False , mode = black .Mode ())
327+ f .write (generated )
328+ _add_file_to_sources_txt_file (dst_path )
329+
330+
255331def generate_all_upstream_trampolines ():
256332 import mlir .dialects
257333 import mlir_utils .dialects
258334
259335 for mod in pkgutil .iter_modules (mlir .dialects .__path__ ):
260- if not mod .name .startswith ("_" ):
336+ if not mod .name .startswith ("_" ) and mod . name != "linalg" :
261337 generate_trampolines (
262338 f"mlir.dialects.{ mod .name } " ,
263339 Path (mlir_utils .dialects .__path__ [0 ]),
264340 mod .name ,
265341 )
342+ elif mod .name == "linalg" :
343+ generate_linalg ("mlir.dialects.linalg" )
0 commit comments