Skip to content

Commit f9e173d

Browse files
authored
More robust parsing of ast's from callables (#97)
* Better testing for edge cases
* Deal with indents more naturally
* Allow supplying method name to parser for cases where things are too deep
1 parent d96638b commit f9e173d

File tree

2 files changed

+206
-58
lines changed

2 files changed

+206
-58
lines changed

func_adl/util_ast.py

Lines changed: 102 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ast
44
import inspect
55
import sys
6-
from typing import Any, Callable, Dict, List, Optional, Union, cast
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
77

88
# Some functions to enable backwards compatibility.
99
# Capability may be degraded in older versions - particularly 3.6.
@@ -351,7 +351,25 @@ def global_getclosurevars(f: Callable) -> inspect.ClosureVars:
351351
return cv
352352

353353

354-
def parse_as_ast(ast_source: Union[str, ast.AST, Callable]) -> ast.Lambda:
354+
def _realign_indent(s: str) -> str:
355+
"""Move the first line to be at zero indent, and then apply same for everything
356+
below.
357+
358+
Args:
359+
s (str): The string with indents
360+
361+
Returns:
362+
str: Unindent string
363+
"""
364+
lines = s.split("\n")
365+
spaces = len(lines[0]) - len(lines[0].lstrip())
366+
stripped_lines = [ln[spaces:] for ln in lines]
367+
return "\n".join(stripped_lines)
368+
369+
370+
def parse_as_ast(
371+
ast_source: Union[str, ast.AST, Callable], caller_name: Optional[str] = None
372+
) -> ast.Lambda:
355373
r"""Return an AST for a lambda function from several sources.
356374
357375
We are handed one of several things:
@@ -363,19 +381,25 @@ def parse_as_ast(ast_source: Union[str, ast.AST, Callable]) -> ast.Lambda:
363381
364382
Args:
365383
ast_source: An AST or text string that represents the lambda.
384+
caller_name: The name of the function that the lambda is an arg to. If it
385+
is None, then it will attempt to scan the stack frame above to figure it out.
366386
367387
Returns:
368388
An ast starting from the Lambda AST node.
369389
"""
370390
if callable(ast_source):
371-
source = inspect.getsource(ast_source).strip()
372-
373-
# Look for the name of the calling function (e.g. 'Select' or 'Where')
374-
caller_name = inspect.currentframe().f_back.f_code.co_name # type: ignore
375-
caller_idx = source.find(caller_name)
376-
# If found, parse the string between the parentheses of the function call
377-
if caller_idx > -1:
378-
source = source[caller_idx + len(caller_name) :]
391+
source = _realign_indent(inspect.getsource(ast_source))
392+
393+
def find_next_lambda(method_name: str, source: str) -> Tuple[Optional[str], str]:
394+
"Find the lambda starting at the name"
395+
caller_idx = source.find(method_name)
396+
397+
# If we couldn't find it, then we need to parse the whole thing.
398+
if caller_idx == -1:
399+
return None, source
400+
401+
source = source[caller_idx + len(method_name) :]
402+
379403
i = 0
380404
open_count = 0
381405
while True:
@@ -387,42 +411,81 @@ def parse_as_ast(ast_source: Union[str, ast.AST, Callable]) -> ast.Lambda:
387411
if open_count == 0:
388412
break
389413
i += 1
390-
stem = source[i + 1 :]
391-
new_line = stem.find("\n")
392-
next_caller = stem.find(caller_name)
393-
if next_caller > -1 and (new_line < 0 or new_line > next_caller):
394-
raise ValueError(
395-
f"Found two calls to {caller_name} on same line - " "split accross lines"
396-
)
397-
source = source[: i + 1]
398-
399-
def parse(src: str) -> Optional[ast.Module]:
400-
try:
401-
return ast.parse(src)
402-
except SyntaxError:
403-
return None
404-
405-
# Special case ending with a close parenthesis at the end of a line.
406-
src_ast = parse(source)
407-
if not src_ast and source.endswith(")"):
408-
src_ast = parse(source[:-1])
409414

410-
if not src_ast:
411-
raise ValueError(f"Unable to recover source for function {ast_source}.")
415+
lambda_source = source[: i + 1]
416+
remaining_source = source[i + 1 :]
417+
418+
return lambda_source, remaining_source
419+
420+
# Look for the name of the calling function (e.g. 'Select' or 'Where', etc.) and
421+
# find all the instances on this line.
422+
if caller_name is None:
423+
caller_name = inspect.currentframe().f_back.f_code.co_name # type: ignore
424+
425+
found_lambdas: List[str] = []
426+
while True:
427+
lambda_source, remaining_source = find_next_lambda(caller_name, source)
428+
if lambda_source is None:
429+
break
430+
source = remaining_source
431+
found_lambdas.append(lambda_source)
412432

413-
# If this is a function, not a lambda, then we can morph and return that.
414-
if len(src_ast.body) == 1 and isinstance(src_ast.body[0], ast.FunctionDef):
415-
lda = rewrite_func_as_lambda(src_ast.body[0]) # type: ignore
433+
if len(found_lambdas) == 0:
434+
found_lambdas.append(source)
435+
436+
# Parse them as a lambda function
437+
def parse(src: str) -> Optional[ast.Lambda]:
438+
while True:
439+
try:
440+
a_module = ast.parse(src)
441+
# If this is a function, not a lambda, then we can morph and return that.
442+
if len(a_module.body) == 1 and isinstance(a_module.body[0], ast.FunctionDef):
443+
lda = rewrite_func_as_lambda(a_module.body[0]) # type: ignore
444+
else:
445+
lda = next(
446+
(node for node in ast.walk(a_module) if isinstance(node, ast.Lambda)),
447+
None,
448+
)
449+
450+
if lda is None:
451+
raise ValueError(
452+
f"Unable to recover source for function {ast_source} - '{src}'."
453+
)
454+
return lda
455+
except SyntaxError:
456+
pass
457+
if src.endswith(")"):
458+
src = src[:-1]
459+
else:
460+
return None
461+
462+
parsed_lambdas = [parse(src) for src in found_lambdas]
463+
464+
# If we have more than one lambda, there are some tricks we can try - like argument names,
465+
# to see if they are different.
466+
src_ast: Optional[ast.Lambda] = None
467+
if len(found_lambdas) > 1:
468+
caller_arg_list = inspect.getfullargspec(ast_source).args
469+
for idx, p_lambda in enumerate(parsed_lambdas):
470+
lambda_args = [a.arg for a in p_lambda.args.args] # type: ignore
471+
if lambda_args == caller_arg_list:
472+
if src_ast is not None:
473+
raise ValueError(
474+
f"Found two calls to {caller_name} on same line - "
475+
"split accross lines or change lambda argument names so they "
476+
"are different."
477+
)
478+
src_ast = p_lambda
416479
else:
417-
lda = next((node for node in ast.walk(src_ast) if isinstance(node, ast.Lambda)), None)
480+
assert len(found_lambdas) == 1
481+
src_ast = parsed_lambdas[0]
418482

419-
if lda is None:
420-
raise ValueError(f"Unable to recover source for function {ast_source}.")
483+
if not src_ast:
484+
raise ValueError(f"Unable to recover source for function {ast_source}.")
421485

422486
# Since this is a function in python, we can look for lambda capture.
423487
call_args = global_getclosurevars(ast_source)
424-
425-
return _rewrite_captured_vars(call_args).visit(lda)
488+
return _rewrite_captured_vars(call_args).visit(src_ast)
426489

427490
elif isinstance(ast_source, str):
428491
a = ast.parse(ast_source.strip()) # type: ignore

tests/test_util_ast.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from func_adl.util_ast import (
8+
_realign_indent,
89
as_ast,
910
function_call,
1011
lambda_args,
@@ -366,28 +367,92 @@ def test_known_global_function():
366367
assert "Name(id='global_doit_non_func'" in ast.dump(f)
367368

368369

369-
# TODO: This test is not compatible with the black formatter - it puts
370-
# the doit call on the line twice.
371-
# def test_parse_continues():
372-
# "Emulate the syntax you often find when you have a multistep query"
373-
# found = []
370+
def test_parse_continues():
371+
"Emulate the syntax you often find when you have a multistep query"
372+
found = []
373+
374+
class my_obj:
375+
def do_it(self, x: Callable):
376+
found.append(parse_as_ast(x))
377+
return self
378+
379+
(my_obj().do_it(lambda x: x + 1).do_it(lambda y: y * 2))
380+
381+
assert len(found) == 2
382+
l1, l2 = found
383+
assert isinstance(l1, ast.Lambda)
384+
assert isinstance(l1.body, ast.BinOp)
385+
assert isinstance(l1.body.op, ast.Add)
386+
387+
assert isinstance(l2, ast.Lambda)
388+
assert isinstance(l2.body, ast.BinOp)
389+
assert isinstance(l2.body.op, ast.Mult)
390+
391+
392+
def test_decorator_parse():
393+
"More general case"
394+
395+
seen_lambdas = []
396+
397+
def dec_func(x: Callable):
398+
def make_it(y: Callable):
399+
return y
400+
401+
seen_lambdas.append(parse_as_ast(x))
402+
return make_it
403+
404+
@dec_func(lambda y: y + 2)
405+
def doit(x):
406+
return x + 1
407+
408+
assert len(seen_lambdas) == 1
409+
l1 = seen_lambdas[0]
410+
assert isinstance(l1.body, ast.BinOp)
411+
assert isinstance(l1.body.op, ast.Add)
412+
374413

375-
# class my_obj:
376-
# def do_it(self, x: Callable):
377-
# found.append(parse_as_ast(x))
378-
# return self
414+
def test_indent_parse():
415+
"More general case"
379416

380-
# (my_obj().do_it(lambda x: x + 1).do_it(lambda y: y * 2))
417+
seen_funcs = []
381418

382-
# assert len(found) == 2
383-
# l1, l2 = found
384-
# assert isinstance(l1, ast.Lambda)
385-
# assert isinstance(l1.body, ast.BinOp)
386-
# assert isinstance(l1.body.op, ast.Add)
419+
class h:
420+
@staticmethod
421+
def dec_func(x: Callable):
422+
def make_it(y: Callable):
423+
return y
387424

388-
# assert isinstance(l2, ast.Lambda)
389-
# assert isinstance(l2.body, ast.BinOp)
390-
# assert isinstance(l2.body.op, ast.Mult)
425+
seen_funcs.append(x)
426+
return make_it
427+
428+
class yo_baby:
429+
@h.dec_func(lambda y: y + 2)
430+
def doit(self, x: int):
431+
...
432+
433+
assert len(seen_funcs) == 1
434+
l1 = parse_as_ast(seen_funcs[0], "dec_func")
435+
assert isinstance(l1.body, ast.BinOp)
436+
assert isinstance(l1.body.op, ast.Add)
437+
438+
439+
def test_two_deep_parse():
440+
"More general case"
441+
442+
seen_lambdas = []
443+
444+
def func_bottom(x: Callable):
445+
seen_lambdas.append(parse_as_ast(x))
446+
447+
def func_top(x: Callable):
448+
func_bottom(x)
449+
450+
func_top(lambda x: x + 1)
451+
452+
assert len(seen_lambdas) == 1
453+
l1 = seen_lambdas[0]
454+
assert isinstance(l1.body, ast.BinOp)
455+
assert isinstance(l1.body.op, ast.Add)
391456

392457

393458
def test_parse_continues_one_line():
@@ -400,7 +465,7 @@ def do_it(self, x: Callable):
400465
return self
401466

402467
with pytest.raises(Exception) as e:
403-
my_obj().do_it(lambda x: x + 1).do_it(lambda y: y * 2)
468+
my_obj().do_it(lambda x: x + 1).do_it(lambda x: x * 2)
404469

405470
assert "two" in str(e.value)
406471

@@ -416,3 +481,23 @@ def callback(a: ast.arg):
416481

417482
assert recoreded is not None
418483
assert 22 == ast.literal_eval(recoreded)
484+
485+
486+
def test_realign_no_indent():
487+
assert _realign_indent("test") == "test"
488+
489+
490+
def test_realign_indent_sp():
491+
assert _realign_indent(" test") == "test"
492+
493+
494+
def test_realign_indent_tab():
495+
assert _realign_indent("\ttest") == "test"
496+
497+
498+
def test_realign_indent_2lines():
499+
assert _realign_indent(" test()\n dude()") == "test()\ndude()"
500+
501+
502+
def test_realign_indent_2lines_uneven():
503+
assert _realign_indent(" test()\n dude()") == "test()\n dude()"

0 commit comments

Comments
 (0)