42
42
from mlir .extras .util import mlir_type_to_np_dtype
43
43
44
44
45
+ INDENT = 0
46
+ # OUTPUT_BUF = io.StringIO()
47
+ OUTPUT_BUF = sys .stdout
48
+ ATTR_ALIASES = {}
49
+
50
+
45
51
def normalize_ssa (ssa : str | Value ):
46
52
if isinstance (ssa , Value ):
47
- ssa = ssa .get_name ()
53
+ ssa = ssa .get_name (use_name_loc_as_prefix = True )
48
54
if ssa [1 ].isnumeric ():
49
55
ssa = ssa .replace ("%" , "v" )
50
56
else :
@@ -74,6 +80,8 @@ def np_array_from_shape_type(shape, dtype, splat_value=None):
74
80
75
81
76
82
def map_attr (attr ):
83
+ if attr in ATTR_ALIASES :
84
+ return ATTR_ALIASES [attr ]
77
85
attr = attr .maybe_downcast ()
78
86
if isinstance (attr , (IntegerAttr , BoolAttr , FloatAttr )):
79
87
return attr .value
@@ -130,11 +138,6 @@ def map_type(type):
130
138
return f"Type.parse('{ type } ')"
131
139
132
140
133
- indent = 0
134
- # OUTPUT_BUF = io.StringIO()
135
- OUTPUT_BUF = sys .stdout
136
-
137
-
138
141
def get_init_args (opview ):
139
142
klass = opview .__class__
140
143
while not klass .__base__ is OpView :
@@ -168,7 +171,7 @@ def underscore(word: str) -> str:
168
171
169
172
170
173
def print_opview (opview , name = None ):
171
- print (" " * indent , file = OUTPUT_BUF , end = "" )
174
+ print (" " * INDENT , file = OUTPUT_BUF , end = "" )
172
175
if len (opview .results ):
173
176
print (
174
177
", " .join ([normalize_ssa (r ) for r in opview .results ]),
@@ -249,15 +252,15 @@ def print_opview(opview, name=None):
249
252
else :
250
253
owner = f"{ op_idx_owner_name } "
251
254
print (
252
- " " * indent
255
+ " " * INDENT
253
256
+ f"{ owner } .attributes['OpIdx'] = amdgpu.OpIdxAttr.get({ attrs ['OpIdx' ].value } )" ,
254
257
file = OUTPUT_BUF ,
255
258
)
256
259
257
260
258
261
def print_func_op (func_op : func .FuncOp ):
259
262
# op.print(print_generic_op_form=True)
260
- print (" " * indent , file = OUTPUT_BUF , end = "" )
263
+ print (" " * INDENT , file = OUTPUT_BUF , end = "" )
261
264
print ("@func.func(" , file = OUTPUT_BUF , end = "" )
262
265
if len (func_op .attributes ):
263
266
attrs = []
@@ -283,7 +286,7 @@ def print_func_op(func_op: func.FuncOp):
283
286
284
287
285
288
def print_arith_constant (constop : arith .ConstantOp ):
286
- print (" " * indent , file = OUTPUT_BUF , end = "" )
289
+ print (" " * INDENT , file = OUTPUT_BUF , end = "" )
287
290
print (
288
291
f"{ normalize_ssa (constop .result )} = arith.constant({ map_attr (constop .value )} , { map_type (constop .result .type )} )" ,
289
292
file = OUTPUT_BUF ,
@@ -305,7 +308,7 @@ def print_scf_for(for_op: scf.ForOp):
305
308
)
306
309
init_args = [normalize_ssa (a ) for a in for_op .initArgs ]
307
310
print (
308
- (" " * indent )
311
+ (" " * INDENT )
309
312
+ f"for { opers_str } in scf.for_({ start } , { stop } , { step } , iter_args=[{ ', ' .join (init_args )} ]):" ,
310
313
file = OUTPUT_BUF ,
311
314
)
@@ -315,12 +318,12 @@ def print_scf_if(if_op: scf.IfOp):
315
318
assert len (if_op .results ) == 1
316
319
res = if_op .results [0 ]
317
320
res_name = normalize_ssa (res )
318
- global indent
321
+ global INDENT
319
322
320
323
def print_yield_as_return (yield_op : scf .YieldOp ):
321
324
opers = [normalize_ssa (a ) for a in yield_op .operands ]
322
325
print (
323
- (" " * indent ) + f"return { ', ' .join (opers )} " ,
326
+ (" " * INDENT ) + f"return { ', ' .join (opers )} " ,
324
327
file = OUTPUT_BUF ,
325
328
)
326
329
@@ -332,17 +335,17 @@ def print_yield_as_return(yield_op: scf.YieldOp):
332
335
def { res_name } ():\
333
336
"""
334
337
),
335
- " " * indent ,
338
+ " " * INDENT ,
336
339
),
337
340
file = OUTPUT_BUF ,
338
341
)
339
- indent += 1
342
+ INDENT += 1
340
343
for bodyop in if_op .thenRegion .blocks [0 ].operations :
341
344
if isinstance (bodyop , scf .YieldOp ):
342
345
print_yield_as_return (bodyop )
343
346
else :
344
347
bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
345
- indent -= 1
348
+ INDENT -= 1
346
349
print (
347
350
textwrap .indent (
348
351
textwrap .dedent (
@@ -351,17 +354,17 @@ def {res_name}():\
351
354
def { res_name } _else():\
352
355
""" ,
353
356
),
354
- " " * indent ,
357
+ " " * INDENT ,
355
358
),
356
359
file = OUTPUT_BUF ,
357
360
)
358
- indent += 1
361
+ INDENT += 1
359
362
for bodyop in if_op .elseRegion .blocks [0 ].operations :
360
363
if isinstance (bodyop , scf .YieldOp ):
361
364
print_yield_as_return (bodyop )
362
365
else :
363
366
bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
364
- indent -= 1
367
+ INDENT -= 1
365
368
366
369
367
370
def generic_print_walk_callback (op ):
@@ -392,16 +395,26 @@ def generic_print_walk_callback(op):
392
395
print_opview (opview )
393
396
394
397
if len (op .regions ):
395
- global indent
396
- indent += 1
398
+ global INDENT
399
+ INDENT += 1
397
400
for bodyop in op .regions [0 ].blocks [0 ].operations :
398
401
bodyop .walk (generic_print_walk_callback , WalkOrder .PRE_ORDER )
399
- indent -= 1
402
+ INDENT -= 1
400
403
return WalkResult .SKIP
401
404
402
405
return WalkResult .ADVANCE
403
406
404
407
408
+ def print_attr_alias (attr_line : str ):
409
+ print (attr_line )
410
+ alias_name , attr_str = attr_line .split (" = " , maxsplit = 1 )
411
+ assert alias_name .startswith ("#" )
412
+ alias_name = alias_name [1 :]
413
+ attr = Attribute .parse (attr_str )
414
+ print (f"{ alias_name } = { map_attr (attr )} " , file = OUTPUT_BUF )
415
+ ATTR_ALIASES [attr ] = alias_name
416
+
417
+
405
418
def main () -> None :
406
419
parser = argparse .ArgumentParser ()
407
420
parser .add_argument ("input_file" , type = Path )
0 commit comments