Skip to content

Commit 576c4eb

Browse files
committed
generate linalg correctly
1 parent b992f9c commit 576c4eb

File tree

2 files changed

+79
-25
lines changed

2 files changed

+79
-25
lines changed

mlir_utils/_configuration/generate_trampolines.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,92 @@ def generate_trampolines(
252252
_add_file_to_sources_txt_file(dst_path)
253253

254254

255+
def generate_linalg(mod_path):
256+
import mlir_utils.dialects
257+
import mlir.dialects
258+
from mlir_utils._configuration.configuration import _add_file_to_sources_txt_file
259+
from mlir.dialects.linalg import DefinedOpCallable, OperandKind
260+
from mlir_utils.util import (
261+
get_result_or_results,
262+
maybe_cast,
263+
region_op,
264+
get_user_code_loc,
265+
)
266+
267+
linalg_modu = __import__(mod_path, fromlist=["*"])
268+
269+
dst_path = Path(mlir_utils.dialects.__path__[0])
270+
dst_path = dst_path / f"linalg.py"
271+
with open(dst_path, "w") as f:
272+
functions = []
273+
linalg_import = ast.ImportFrom(
274+
module=mlir.dialects.__name__,
275+
names=[ast.Name("linalg")],
276+
level=0,
277+
)
278+
ods_imports = ast.ImportFrom(
279+
module=mlir_utils.util.__name__,
280+
names=[ast.alias(f.__name__) for f in [maybe_cast, get_user_code_loc]],
281+
level=0,
282+
)
283+
_keywords = [
284+
ast.keyword("loc", ast.Name("loc")),
285+
ast.keyword("ip", ast.Name("ip")),
286+
]
287+
for name, op_callable in inspect.getmembers(
288+
linalg_modu, lambda x: isinstance(x, DefinedOpCallable)
289+
):
290+
inputs = [
291+
ast.Name(o)
292+
for o, def_ in op_callable.op_def.registered_operands.items()
293+
if def_.kind == OperandKind.INPUT_TENSOR
294+
]
295+
outputs = [
296+
ast.Name(o)
297+
for o, def_ in op_callable.op_def.registered_operands.items()
298+
if def_.kind == OperandKind.OUTPUT_TENSOR
299+
]
300+
301+
keywords = _keywords + [ast.keyword("outs", ast.List(outputs))]
302+
# body = [ast.Str(op_callable.op_def.metadata.doc)]
303+
body = [ast.parse(f"if loc is None: loc = {get_user_code_loc.__name__}()")]
304+
body += [
305+
ast.parse(
306+
f"return {maybe_cast.__name__}({ast.unparse(ast_call('linalg.' + name, inputs, keywords))})"
307+
).body[0]
308+
]
309+
n = ast.FunctionDef(
310+
name=op_callable.op_name,
311+
args=ast.arguments(
312+
posonlyargs=[],
313+
args=inputs + outputs,
314+
defaults=[],
315+
kwonlyargs=[ast.Name("loc"), ast.Name("ip")],
316+
kw_defaults=[ast.Name("None"), ast.Name("None")],
317+
),
318+
body=body,
319+
decorator_list=[],
320+
)
321+
ast.fix_missing_locations(n)
322+
functions.append(n)
323+
324+
new_mod = ast.Module([linalg_import, ods_imports] + functions, [])
325+
new_src = ast.unparse(new_mod)
326+
generated = black.format_file_contents(new_src, fast=False, mode=black.Mode())
327+
f.write(generated)
328+
_add_file_to_sources_txt_file(dst_path)
329+
330+
255331
def generate_all_upstream_trampolines():
256332
import mlir.dialects
257333
import mlir_utils.dialects
258334

259335
for mod in pkgutil.iter_modules(mlir.dialects.__path__):
260-
if not mod.name.startswith("_"):
336+
if not mod.name.startswith("_") and mod.name != "linalg":
261337
generate_trampolines(
262338
f"mlir.dialects.{mod.name}",
263339
Path(mlir_utils.dialects.__path__[0]),
264340
mod.name,
265341
)
342+
elif mod.name == "linalg":
343+
generate_linalg("mlir.dialects.linalg")

mlir_utils/dialects/ext/scf.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -59,30 +59,6 @@ def _for(
5959

6060
for_ = region_op(_for, terminator=yield__)
6161

62-
# def range_(
63-
# start,
64-
# stop=None,
65-
# step=None,
66-
# iter_args: Optional[Sequence[Value]] = None,
67-
# *,
68-
# loc=None,
69-
# ip=None,
70-
# ):
71-
# for_op = _for(start, stop, step, iter_args, loc=loc, ip=ip)
72-
# iv = maybe_cast(for_op.induction_variable)
73-
# for_iter_args = tuple(map(maybe_cast, for_op.inner_iter_args))
74-
# results = tuple(map(maybe_cast, for_op.results_))
75-
# with InsertionPoint(for_op.body):
76-
# previous_frame = inspect.currentframe().f_back
77-
# _update_caller_vars(previous_frame, iter_args, for_iter_args)
78-
#
79-
# if len(results) > 1:
80-
# yield iv, results
81-
# elif len(results) == 1:
82-
# yield iv, results[0]
83-
# else:
84-
# yield iv
85-
8662

8763
def range_(
8864
start,

0 commit comments

Comments
 (0)