Skip to content

Commit 80e962a

Browse files
authored
factor out get parent of type (#70)
1 parent d82dd3a commit 80e962a

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

mlir/extras/util.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,30 @@ def __new__(cls, name, bases, classdict, **kwargs):
360360
return new
361361

362362

363+
def find_parent_of_type(test_cb, operation=None):
364+
if isinstance(operation, OpView):
365+
operation = operation.operation
366+
if operation is None:
367+
parent = InsertionPoint.current.block.owner
368+
else:
369+
parent = operation.parent
370+
for _ in range(10):
371+
if test_cb(parent):
372+
return parent
373+
else:
374+
parent = parent.parent
375+
376+
raise RuntimeError("Couldn't matching parent of type")
377+
378+
379+
def is_symbol_table(operation):
380+
try:
381+
SymbolTable(operation)
382+
return True
383+
except RuntimeError:
384+
return False
385+
386+
363387
def _get_sym_name(previous_frame, check_func_call=None):
364388
try:
365389
with open(inspect.getfile(previous_frame)) as src_file:
@@ -369,16 +393,7 @@ def _get_sym_name(previous_frame, check_func_call=None):
369393
if check_func_call is not None:
370394
assert re.match(check_func_call, func_call)
371395
maybe_unique_sym_name = ident
372-
parent = InsertionPoint.current.block.owner
373-
for _ in range(10):
374-
try:
375-
symbol_table = SymbolTable(parent.operation)
376-
break
377-
except RuntimeError:
378-
parent = parent.parent
379-
else:
380-
raise RuntimeError("Couldn't find symbol table")
381-
396+
symbol_table = SymbolTable(find_parent_of_type(is_symbol_table).operation)
382397
while maybe_unique_sym_name in symbol_table:
383398
if re.match(r".*_(\d+)$", maybe_unique_sym_name):
384399
maybe_unique_sym_name = re.sub(

0 commit comments

Comments
 (0)