Skip to content

Commit b6416ea

Browse files
authored
Bug Fix: argument names should override captured variables (#100)
* Fixes a bug that is part of captured variables logic * If `jets` was used as both a lambda argument an a global constant, then it would be replaced by its value in the AST Fixes #95
1 parent a8b3e1f commit b6416ea

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

func_adl/util_ast.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,14 +322,28 @@ class _rewrite_captured_vars(ast.NodeTransformer):
322322
def __init__(self, cv: inspect.ClosureVars):
323323
self._lookup_dict: Dict[str, Any] = dict(cv.nonlocals)
324324
self._lookup_dict.update(cv.globals)
325+
self._ignore_stack = []
325326

326327
def visit_Name(self, node: ast.Name) -> Any:
328+
if self.is_arg(node.id):
329+
return node
330+
327331
if node.id in self._lookup_dict:
328332
v = self._lookup_dict[node.id]
329333
if not callable(v):
330334
return as_literal(self._lookup_dict[node.id])
331335
return node
332336

337+
def visit_Lambda(self, node: ast.Lambda) -> Any:
338+
self._ignore_stack.append([a.arg for a in node.args.args])
339+
v = super().generic_visit(node)
340+
self._ignore_stack.pop()
341+
return v
342+
343+
def is_arg(self, a_name: str) -> bool:
344+
"If the arg is on the stack, then return true"
345+
return any([a == a_name for frames in self._ignore_stack for a in frames])
346+
333347

334348
def global_getclosurevars(f: Callable) -> inspect.ClosureVars:
335349
"""Grab the closure over all passed function. Add all known global

tests/test_util_ast.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,23 @@ def test_parse_lambda_capture():
279279
assert ast.dump(r) == ast.dump(r_true)
280280

281281

282+
def test_parse_lambda_capture_ignore_local():
283+
x = 30 # NOQA type: ignore
284+
r = parse_as_ast(lambda x: x > 20)
285+
r_true = parse_as_ast(lambda y: y > 20)
286+
assert ast.dump(r) == ast.dump(r_true).replace("'y'", "'x'")
287+
288+
282289
g_cut_value = 30
283290

284291

292+
def test_parse_lambda_capture_ignore_global():
293+
x = 30 # NOQA type: ignore
294+
r = parse_as_ast(lambda g_cut_value: g_cut_value > 20)
295+
r_true = parse_as_ast(lambda y: y > 20)
296+
assert ast.dump(r) == ast.dump(r_true).replace("'y'", "'g_cut_value'")
297+
298+
285299
def test_parse_lambda_capture_nested_global():
286300
r = parse_as_ast(lambda x: (lambda y: y > g_cut_value)(x))
287301
r_true = parse_as_ast(lambda x: (lambda y: y > 30)(x))

0 commit comments

Comments
 (0)