Skip to content

Commit d82dd3a

Browse files
authored
factor out prev frame get idents (#69)
1 parent ea7255d commit d82dd3a

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

mlir/extras/dialects/ext/gpu.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from ...meta import (
99
region_op,
1010
)
11-
from ...util import ModuleMeta, get_user_code_loc, make_maybe_no_args_decorator
11+
from ...util import (
12+
ModuleMeta,
13+
_get_previous_frame_idents,
14+
get_user_code_loc,
15+
make_maybe_no_args_decorator,
16+
)
1217
from ....dialects._gpu_ops_gen import _Dialect
1318
from ....dialects._ods_common import (
1419
_cext,
@@ -327,14 +332,7 @@ def __init__(self, func):
327332

328333
def __getitem__(self, item):
329334
previous_frame = inspect.currentframe().f_back
330-
var_names = [
331-
[
332-
var_name
333-
for var_name, var_val in previous_frame.f_locals.items()
334-
if var_val is arg
335-
]
336-
for arg in item
337-
]
335+
var_names = [_get_previous_frame_idents(arg, previous_frame) for arg in item]
338336
kwargs = {}
339337
for i, it in enumerate(item):
340338
assert len(var_names[i]) == 1, "expected unique kwarg"

mlir/extras/util.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ def memref_type_to_np_dtype(memref_type):
231231
return _memref_type_to_np_dtype.get(memref_type)
232232

233233

234+
def _get_previous_frame_idents(val, previous_frame):
235+
return [
236+
var_name
237+
for var_name, var_val in previous_frame.f_locals.items()
238+
if var_val is val
239+
]
240+
241+
234242
def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence):
235243
"""Update caller vars passed as args.
236244
@@ -249,14 +257,7 @@ def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence):
249257
if len(args) != len(replacements):
250258
raise ValueError(f"updates must be 1-1: {args=} {replacements=}")
251259
# find the name of the iter args in the previous frame
252-
var_names = [
253-
[
254-
var_name
255-
for var_name, var_val in previous_frame.f_locals.items()
256-
if var_val is arg
257-
]
258-
for arg in args
259-
]
260+
var_names = [_get_previous_frame_idents(arg, previous_frame) for arg in args]
260261
for i, var_names in enumerate(var_names):
261262
for var_name in var_names:
262263
previous_frame.f_locals[var_name] = replacements[i]

0 commit comments

Comments
 (0)