Skip to content

Commit ea7255d

Browse files
authored
check symboltable for symname (#68)
1 parent ade0e39 commit ea7255d

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
from ....dialects import arith, memref
1919
from ....dialects._ods_common import get_op_result_or_op_results
2020
from ....dialects.memref import *
21-
from ....ir import DenseElementsAttr, MemRefType, ShapedType, Type, Value
21+
from ....ir import (
22+
DenseElementsAttr,
23+
MemRefType,
24+
ShapedType,
25+
Type,
26+
Value,
27+
InsertionPoint,
28+
)
2229

2330
S = ShapedType.get_dynamic_size()
2431

mlir/extras/util.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
OpView,
2727
Operation,
2828
RankedTensorType,
29+
SymbolTable,
2930
Type,
3031
Value,
3132
VectorType,
@@ -366,7 +367,25 @@ def _get_sym_name(previous_frame, check_func_call=None):
366367
ident, func_call = map(lambda x: x.strip(), src_line.split("=", maxsplit=1))
367368
if check_func_call is not None:
368369
assert re.match(check_func_call, func_call)
369-
return ident
370+
maybe_unique_sym_name = ident
371+
parent = InsertionPoint.current.block.owner
372+
for _ in range(10):
373+
try:
374+
symbol_table = SymbolTable(parent.operation)
375+
break
376+
except RuntimeError:
377+
parent = parent.parent
378+
else:
379+
raise RuntimeError("Couldn't find symbol table")
380+
381+
while maybe_unique_sym_name in symbol_table:
382+
if re.match(r".*_(\d+)$", maybe_unique_sym_name):
383+
maybe_unique_sym_name = re.sub(
384+
r"(\d+)$", lambda m: str(int(m.group(0)) + 1), maybe_unique_sym_name
385+
)
386+
else:
387+
maybe_unique_sym_name = f"{maybe_unique_sym_name}_0"
388+
return maybe_unique_sym_name
370389
except:
371390
return None
372391

0 commit comments

Comments
 (0)