Skip to content

Commit f2cff77

Browse files
authored
print attr aliases (#137)
1 parent ebb0061 commit f2cff77

File tree

1 file changed

+35
-22
lines changed

1 file changed

+35
-22
lines changed

scripts/lift_mlir_to_python.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@
4242
from mlir.extras.util import mlir_type_to_np_dtype
4343

4444

45+
INDENT = 0
46+
# OUTPUT_BUF = io.StringIO()
47+
OUTPUT_BUF = sys.stdout
48+
ATTR_ALIASES = {}
49+
50+
4551
def normalize_ssa(ssa: str | Value):
4652
if isinstance(ssa, Value):
47-
ssa = ssa.get_name()
53+
ssa = ssa.get_name(use_name_loc_as_prefix=True)
4854
if ssa[1].isnumeric():
4955
ssa = ssa.replace("%", "v")
5056
else:
@@ -74,6 +80,8 @@ def np_array_from_shape_type(shape, dtype, splat_value=None):
7480

7581

7682
def map_attr(attr):
83+
if attr in ATTR_ALIASES:
84+
return ATTR_ALIASES[attr]
7785
attr = attr.maybe_downcast()
7886
if isinstance(attr, (IntegerAttr, BoolAttr, FloatAttr)):
7987
return attr.value
@@ -130,11 +138,6 @@ def map_type(type):
130138
return f"Type.parse('{type}')"
131139

132140

133-
indent = 0
134-
# OUTPUT_BUF = io.StringIO()
135-
OUTPUT_BUF = sys.stdout
136-
137-
138141
def get_init_args(opview):
139142
klass = opview.__class__
140143
while not klass.__base__ is OpView:
@@ -168,7 +171,7 @@ def underscore(word: str) -> str:
168171

169172

170173
def print_opview(opview, name=None):
171-
print(" " * indent, file=OUTPUT_BUF, end="")
174+
print(" " * INDENT, file=OUTPUT_BUF, end="")
172175
if len(opview.results):
173176
print(
174177
", ".join([normalize_ssa(r) for r in opview.results]),
@@ -249,15 +252,15 @@ def print_opview(opview, name=None):
249252
else:
250253
owner = f"{op_idx_owner_name}"
251254
print(
252-
" " * indent
255+
" " * INDENT
253256
+ f"{owner}.attributes['OpIdx'] = amdgpu.OpIdxAttr.get({attrs['OpIdx'].value})",
254257
file=OUTPUT_BUF,
255258
)
256259

257260

258261
def print_func_op(func_op: func.FuncOp):
259262
# op.print(print_generic_op_form=True)
260-
print(" " * indent, file=OUTPUT_BUF, end="")
263+
print(" " * INDENT, file=OUTPUT_BUF, end="")
261264
print("@func.func(", file=OUTPUT_BUF, end="")
262265
if len(func_op.attributes):
263266
attrs = []
@@ -283,7 +286,7 @@ def print_func_op(func_op: func.FuncOp):
283286

284287

285288
def print_arith_constant(constop: arith.ConstantOp):
286-
print(" " * indent, file=OUTPUT_BUF, end="")
289+
print(" " * INDENT, file=OUTPUT_BUF, end="")
287290
print(
288291
f"{normalize_ssa(constop.result)} = arith.constant({map_attr(constop.value)}, {map_type(constop.result.type)})",
289292
file=OUTPUT_BUF,
@@ -305,7 +308,7 @@ def print_scf_for(for_op: scf.ForOp):
305308
)
306309
init_args = [normalize_ssa(a) for a in for_op.initArgs]
307310
print(
308-
(" " * indent)
311+
(" " * INDENT)
309312
+ f"for {opers_str} in scf.for_({start}, {stop}, {step}, iter_args=[{', '.join(init_args)}]):",
310313
file=OUTPUT_BUF,
311314
)
@@ -315,12 +318,12 @@ def print_scf_if(if_op: scf.IfOp):
315318
assert len(if_op.results) == 1
316319
res = if_op.results[0]
317320
res_name = normalize_ssa(res)
318-
global indent
321+
global INDENT
319322

320323
def print_yield_as_return(yield_op: scf.YieldOp):
321324
opers = [normalize_ssa(a) for a in yield_op.operands]
322325
print(
323-
(" " * indent) + f"return {', '.join(opers)}",
326+
(" " * INDENT) + f"return {', '.join(opers)}",
324327
file=OUTPUT_BUF,
325328
)
326329

@@ -332,17 +335,17 @@ def print_yield_as_return(yield_op: scf.YieldOp):
332335
def {res_name}():\
333336
"""
334337
),
335-
" " * indent,
338+
" " * INDENT,
336339
),
337340
file=OUTPUT_BUF,
338341
)
339-
indent += 1
342+
INDENT += 1
340343
for bodyop in if_op.thenRegion.blocks[0].operations:
341344
if isinstance(bodyop, scf.YieldOp):
342345
print_yield_as_return(bodyop)
343346
else:
344347
bodyop.walk(generic_print_walk_callback, WalkOrder.PRE_ORDER)
345-
indent -= 1
348+
INDENT -= 1
346349
print(
347350
textwrap.indent(
348351
textwrap.dedent(
@@ -351,17 +354,17 @@ def {res_name}():\
351354
def {res_name}_else():\
352355
""",
353356
),
354-
" " * indent,
357+
" " * INDENT,
355358
),
356359
file=OUTPUT_BUF,
357360
)
358-
indent += 1
361+
INDENT += 1
359362
for bodyop in if_op.elseRegion.blocks[0].operations:
360363
if isinstance(bodyop, scf.YieldOp):
361364
print_yield_as_return(bodyop)
362365
else:
363366
bodyop.walk(generic_print_walk_callback, WalkOrder.PRE_ORDER)
364-
indent -= 1
367+
INDENT -= 1
365368

366369

367370
def generic_print_walk_callback(op):
@@ -392,16 +395,26 @@ def generic_print_walk_callback(op):
392395
print_opview(opview)
393396

394397
if len(op.regions):
395-
global indent
396-
indent += 1
398+
global INDENT
399+
INDENT += 1
397400
for bodyop in op.regions[0].blocks[0].operations:
398401
bodyop.walk(generic_print_walk_callback, WalkOrder.PRE_ORDER)
399-
indent -= 1
402+
INDENT -= 1
400403
return WalkResult.SKIP
401404

402405
return WalkResult.ADVANCE
403406

404407

408+
def print_attr_alias(attr_line: str):
409+
print(attr_line)
410+
alias_name, attr_str = attr_line.split(" = ", maxsplit=1)
411+
assert alias_name.startswith("#")
412+
alias_name = alias_name[1:]
413+
attr = Attribute.parse(attr_str)
414+
print(f"{alias_name} = {map_attr(attr)}", file=OUTPUT_BUF)
415+
ATTR_ALIASES[attr] = alias_name
416+
417+
405418
def main() -> None:
406419
parser = argparse.ArgumentParser()
407420
parser.add_argument("input_file", type=Path)

0 commit comments

Comments
 (0)