Skip to content

Commit 2e4b59e

Browse files
authored
Use Function Names as Clues Parsing Lambda's (#116)
* Make sure we use the fact we are after a `Where` or `Select` when parsing multiple `lambda`'s on a line * Improve the error message when we can't tell the difference between mtuple lambda's. Fixes #115
1 parent ace1464 commit 2e4b59e

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

func_adl/object_stream.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def SelectMany(
110110
from func_adl.type_based_replacement import remap_from_lambda
111111

112112
n_stream, n_ast, rtn_type = remap_from_lambda(
113-
self, _local_simplification(parse_as_ast(func))
113+
self, _local_simplification(parse_as_ast(func, "SelectMany"))
114114
)
115115
return ObjectStream[S](
116116
function_call("SelectMany", [n_stream.query_ast, cast(ast.AST, n_ast)]),
@@ -136,7 +136,9 @@ def Select(self, f: Union[str, ast.Lambda, Callable[[T], S]]) -> ObjectStream[S]
136136
"""
137137
from func_adl.type_based_replacement import remap_from_lambda
138138

139-
n_stream, n_ast, rtn_type = remap_from_lambda(self, _local_simplification(parse_as_ast(f)))
139+
n_stream, n_ast, rtn_type = remap_from_lambda(
140+
self, _local_simplification(parse_as_ast(f, "Select"))
141+
)
140142
return ObjectStream[S](
141143
function_call("Select", [n_stream.query_ast, cast(ast.AST, n_ast)]), rtn_type
142144
)
@@ -160,7 +162,7 @@ def Where(self, filter: Union[str, ast.Lambda, Callable[[T], bool]]) -> ObjectSt
160162
from func_adl.type_based_replacement import remap_from_lambda
161163

162164
n_stream, n_ast, rtn_type = remap_from_lambda(
163-
self, _local_simplification(parse_as_ast(filter))
165+
self, _local_simplification(parse_as_ast(filter, "Where"))
164166
)
165167
if rtn_type != bool:
166168
raise ValueError(f"The Where filter must return a boolean (not {rtn_type})")

func_adl/util_ast.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,10 +647,14 @@ def lambda_arg_list(lda: ast.Lambda) -> List[str]:
647647

648648
if len(good_lambdas) > 1:
649649
raise ValueError(
650-
"Found multiple calls to on same line"
650+
"Found multiple calls on same line"
651651
+ ("" if caller_name is None else f" for {caller_name}")
652652
+ " - split the calls across "
653-
"lines or change lambda argument names so they are different."
653+
"""lines or change lambda argument names so they are different. For example change:
654+
df.Select(lambda x: x + 1).Select(lambda x: x + 2)
655+
to:
656+
df.Select(lambda x: x + 1).Select(lambda y: y + 2)
657+
"""
654658
)
655659

656660
lda = good_lambdas[0]

tests/test_object_stream.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ def test_simple_query():
7575
assert isinstance(r, ast.AST)
7676

7777

78+
def test_simple_query_one_line():
79+
"""Make sure we parse 2 functions on one line correctly"""
80+
r = my_event().Select(lambda e: e.met).Where(lambda e: e > 10).value()
81+
assert isinstance(r, ast.AST)
82+
83+
7884
def test_two_simple_query():
7985
r1 = (
8086
my_event()

tests/test_util_ast.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,38 @@ def value(self):
682682
assert "uncalibrated_collection" in ast.dump(found[0])
683683

684684

685+
def test_parse_select_where():
686+
"Common lambas with different parent functions on one line - found in wild"
687+
688+
found = []
689+
690+
class my_obj:
691+
def Where(self, x: Callable):
692+
found.append(parse_as_ast(x, "Where"))
693+
return self
694+
695+
def Select(self, x: Callable):
696+
found.append(parse_as_ast(x, "Select"))
697+
return self
698+
699+
def AsAwkwardArray(self, stuff: str):
700+
return self
701+
702+
def value(self):
703+
return self
704+
705+
jets_pflow_name = "hi"
706+
ds_dijet = my_obj()
707+
708+
# fmt: off
709+
jets_pflow = (
710+
ds_dijet.Select(lambda e: e.met).Where(lambda e: e > 100)
711+
)
712+
# fmt: on
713+
assert jets_pflow is not None # Just to keep flake8 happy without adding a noqa above.
714+
assert "met" in ast.dump(found[0])
715+
716+
685717
def test_parse_multiline_lambda_ok_with_one_as_arg():
686718
"Make sure we can properly parse a multi-line lambda - but now with argument"
687719

0 commit comments

Comments
 (0)