33import ast
44import inspect
55import sys
6- from typing import Any , Callable , Dict , List , Optional , Union , cast
6+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union , cast
77
88# Some functions to enable backwards compatibility.
99# Capability may be degraded in older versions - particularly 3.6.
@@ -351,7 +351,25 @@ def global_getclosurevars(f: Callable) -> inspect.ClosureVars:
351351 return cv
352352
353353
354- def parse_as_ast (ast_source : Union [str , ast .AST , Callable ]) -> ast .Lambda :
354+ def _realign_indent (s : str ) -> str :
355+ """Move the first line to be at zero indent, and then apply same for everything
356+ below.
357+
358+ Args:
359+ s (str): The string with indents
360+
361+ Returns:
362+ str: Unindent string
363+ """
364+ lines = s .split ("\n " )
365+ spaces = len (lines [0 ]) - len (lines [0 ].lstrip ())
366+ stripped_lines = [ln [spaces :] for ln in lines ]
367+ return "\n " .join (stripped_lines )
368+
369+
370+ def parse_as_ast (
371+ ast_source : Union [str , ast .AST , Callable ], caller_name : Optional [str ] = None
372+ ) -> ast .Lambda :
355373 r"""Return an AST for a lambda function from several sources.
356374
357375 We are handed one of several things:
@@ -363,19 +381,25 @@ def parse_as_ast(ast_source: Union[str, ast.AST, Callable]) -> ast.Lambda:
363381
364382 Args:
365383 ast_source: An AST or text string that represnets the lambda.
384+ caller_name: The name of the function that the lambda is an arg to. If it
385+ is none, then it will attempt to scan the stack frame above to figure it out.
366386
367387 Returns:
368388 An ast starting from the Lambda AST node.
369389 """
370390 if callable (ast_source ):
371- source = inspect .getsource (ast_source ).strip ()
372-
373- # Look for the name of the calling function (e.g. 'Select' or 'Where')
374- caller_name = inspect .currentframe ().f_back .f_code .co_name # type: ignore
375- caller_idx = source .find (caller_name )
376- # If found, parse the string between the parentheses of the function call
377- if caller_idx > - 1 :
378- source = source [caller_idx + len (caller_name ) :]
391+ source = _realign_indent (inspect .getsource (ast_source ))
392+
393+ def find_next_lambda (method_name : str , source : str ) -> Tuple [Optional [str ], str ]:
394+ "Find the lambda starting at the name"
395+ caller_idx = source .find (method_name )
396+
397+ # If we couldn't find it, then we need to parse the whole thing.
398+ if caller_idx == - 1 :
399+ return None , source
400+
401+ source = source [caller_idx + len (method_name ) :]
402+
379403 i = 0
380404 open_count = 0
381405 while True :
@@ -387,42 +411,81 @@ def parse_as_ast(ast_source: Union[str, ast.AST, Callable]) -> ast.Lambda:
387411 if open_count == 0 :
388412 break
389413 i += 1
390- stem = source [i + 1 :]
391- new_line = stem .find ("\n " )
392- next_caller = stem .find (caller_name )
393- if next_caller > - 1 and (new_line < 0 or new_line > next_caller ):
394- raise ValueError (
395- f"Found two calls to { caller_name } on same line - " "split accross lines"
396- )
397- source = source [: i + 1 ]
398-
399- def parse (src : str ) -> Optional [ast .Module ]:
400- try :
401- return ast .parse (src )
402- except SyntaxError :
403- return None
404-
405- # Special case ending with a close parenthesis at the end of a line.
406- src_ast = parse (source )
407- if not src_ast and source .endswith (")" ):
408- src_ast = parse (source [:- 1 ])
409414
410- if not src_ast :
411- raise ValueError (f"Unable to recover source for function { ast_source } ." )
415+ lambda_source = source [: i + 1 ]
416+ remaining_source = source [i + 1 :]
417+
418+ return lambda_source , remaining_source
419+
420+ # Look for the name of the calling function (e.g. 'Select' or 'Where', etc.) and
421+ # find all the instances on this line.
422+ if caller_name is None :
423+ caller_name = inspect .currentframe ().f_back .f_code .co_name # type: ignore
424+
425+ found_lambdas : List [str ] = []
426+ while True :
427+ lambda_source , remaining_source = find_next_lambda (caller_name , source )
428+ if lambda_source is None :
429+ break
430+ source = remaining_source
431+ found_lambdas .append (lambda_source )
412432
413- # If this is a function, not a lambda, then we can morph and return that.
414- if len (src_ast .body ) == 1 and isinstance (src_ast .body [0 ], ast .FunctionDef ):
415- lda = rewrite_func_as_lambda (src_ast .body [0 ]) # type: ignore
433+ if len (found_lambdas ) == 0 :
434+ found_lambdas .append (source )
435+
436+ # Parse them as a lambda function
437+ def parse (src : str ) -> Optional [ast .Lambda ]:
438+ while True :
439+ try :
440+ a_module = ast .parse (src )
441+ # If this is a function, not a lambda, then we can morph and return that.
442+ if len (a_module .body ) == 1 and isinstance (a_module .body [0 ], ast .FunctionDef ):
443+ lda = rewrite_func_as_lambda (a_module .body [0 ]) # type: ignore
444+ else :
445+ lda = next (
446+ (node for node in ast .walk (a_module ) if isinstance (node , ast .Lambda )),
447+ None ,
448+ )
449+
450+ if lda is None :
451+ raise ValueError (
452+ f"Unable to recover source for function { ast_source } - '{ src } '."
453+ )
454+ return lda
455+ except SyntaxError :
456+ pass
457+ if src .endswith (")" ):
458+ src = src [:- 1 ]
459+ else :
460+ return None
461+
462+ parsed_lambdas = [parse (src ) for src in found_lambdas ]
463+
464+ # If we have more than one lambda, there are some tricks we can try - like argument names,
465+ # to see if they are different.
466+ src_ast : Optional [ast .Lambda ] = None
467+ if len (found_lambdas ) > 1 :
468+ caller_arg_list = inspect .getfullargspec (ast_source ).args
469+ for idx , p_lambda in enumerate (parsed_lambdas ):
470+ lambda_args = [a .arg for a in p_lambda .args .args ] # type: ignore
471+ if lambda_args == caller_arg_list :
472+ if src_ast is not None :
473+ raise ValueError (
474+ f"Found two calls to { caller_name } on same line - "
475+ "split accross lines or change lambda argument names so they "
476+ "are different."
477+ )
478+ src_ast = p_lambda
416479 else :
417- lda = next ((node for node in ast .walk (src_ast ) if isinstance (node , ast .Lambda )), None )
480+ assert len (found_lambdas ) == 1
481+ src_ast = parsed_lambdas [0 ]
418482
419- if lda is None :
420- raise ValueError (f"Unable to recover source for function { ast_source } ." )
483+ if not src_ast :
484+ raise ValueError (f"Unable to recover source for function { ast_source } ." )
421485
422486 # Since this is a function in python, we can look for lambda capture.
423487 call_args = global_getclosurevars (ast_source )
424-
425- return _rewrite_captured_vars (call_args ).visit (lda )
488+ return _rewrite_captured_vars (call_args ).visit (src_ast )
426489
427490 elif isinstance (ast_source , str ):
428491 a = ast .parse (ast_source .strip ()) # type: ignore
0 commit comments