Skip to content

Commit a4850b9

Browse files
Optimize comprehension evaluation with generator-based approach in LocalPythonExecutor (#1824)
1 parent 97963c9 commit a4850b9

File tree

1 file changed

+62
-39
lines changed

1 file changed

+62
-39
lines changed

src/smolagents/local_python_executor.py

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,31 +1006,47 @@ def evaluate_for(
10061006
return result
10071007

10081008

1009-
def evaluate_nested_comp_helper(
1010-
generators: list[ast.comprehension],
1011-
base_case_evaluator: Callable[[dict[str, Any]], list],
1009+
def _evaluate_comprehensions(
1010+
comprehensions: list[ast.comprehension],
1011+
evaluate_element: Callable[[dict[str, Any]], Any],
10121012
state: dict[str, Any],
10131013
static_tools: dict[str, Callable],
10141014
custom_tools: dict[str, Callable],
10151015
authorized_imports: list[str],
1016-
) -> list:
1017-
def inner_evaluate(index: int, current_state: dict[str, Any]) -> list:
1018-
if index >= len(generators):
1019-
return base_case_evaluator(current_state)
1020-
generator = generators[index]
1021-
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools, authorized_imports)
1022-
results = []
1023-
for value in iter_value:
1024-
new_state = current_state.copy()
1025-
set_value(generator.target, value, new_state, static_tools, custom_tools, authorized_imports)
1026-
if all(
1027-
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
1028-
for if_clause in generator.ifs
1029-
):
1030-
results.extend(inner_evaluate(index + 1, new_state))
1031-
return results
1016+
) -> Generator[Any, None, None]:
1017+
"""
1018+
Recursively evaluate nested comprehensions and yields elements.
10321019
1033-
return inner_evaluate(0, state)
1020+
Args:
1021+
comprehensions (`list[ast.comprehension]`): Comprehensions to evaluate.
1022+
evaluate_element (`Callable`): Function that evaluates the final element when comprehensions are exhausted.
1023+
state (`dict[str, Any]`): Current evaluation state.
1024+
static_tools (`dict[str, Callable]`): Static tools.
1025+
custom_tools (`dict[str, Callable]`): Custom tools.
1026+
authorized_imports (`list[str]`): Authorized imports.
1027+
1028+
Yields:
1029+
`Any`: Individual elements produced by the comprehension
1030+
"""
1031+
# Base case: no more comprehensions
1032+
if not comprehensions:
1033+
yield evaluate_element(state)
1034+
return
1035+
# Evaluate first comprehension
1036+
comprehension = comprehensions[0]
1037+
iter_value = evaluate_ast(comprehension.iter, state, static_tools, custom_tools, authorized_imports)
1038+
for value in iter_value:
1039+
new_state = state.copy()
1040+
set_value(comprehension.target, value, new_state, static_tools, custom_tools, authorized_imports)
1041+
# Check all filter conditions
1042+
if all(
1043+
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
1044+
for if_clause in comprehension.ifs
1045+
):
1046+
# Recurse with remaining comprehensions
1047+
yield from _evaluate_comprehensions(
1048+
comprehensions[1:], evaluate_element, new_state, static_tools, custom_tools, authorized_imports
1049+
)
10341050

10351051

10361052
def evaluate_listcomp(
@@ -1040,12 +1056,15 @@ def evaluate_listcomp(
10401056
custom_tools: dict[str, Callable],
10411057
authorized_imports: list[str],
10421058
) -> list[Any]:
1043-
def base_case(current_state):
1044-
element = evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools, authorized_imports)
1045-
return [element]
1046-
1047-
return evaluate_nested_comp_helper(
1048-
listcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports
1059+
return list(
1060+
_evaluate_comprehensions(
1061+
listcomp.generators,
1062+
lambda comp_state: evaluate_ast(listcomp.elt, comp_state, static_tools, custom_tools, authorized_imports),
1063+
state,
1064+
static_tools,
1065+
custom_tools,
1066+
authorized_imports,
1067+
)
10491068
)
10501069

10511070

@@ -1056,13 +1075,14 @@ def evaluate_setcomp(
10561075
custom_tools: dict[str, Callable],
10571076
authorized_imports: list[str],
10581077
) -> set[Any]:
1059-
def base_case(current_state):
1060-
element = evaluate_ast(setcomp.elt, current_state, static_tools, custom_tools, authorized_imports)
1061-
return [element]
1062-
10631078
return set(
1064-
evaluate_nested_comp_helper(
1065-
setcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports
1079+
_evaluate_comprehensions(
1080+
setcomp.generators,
1081+
lambda comp_state: evaluate_ast(setcomp.elt, comp_state, static_tools, custom_tools, authorized_imports),
1082+
state,
1083+
static_tools,
1084+
custom_tools,
1085+
authorized_imports,
10661086
)
10671087
)
10681088

@@ -1074,14 +1094,17 @@ def evaluate_dictcomp(
10741094
custom_tools: dict[str, Callable],
10751095
authorized_imports: list[str],
10761096
) -> dict[Any, Any]:
1077-
def base_case(current_state):
1078-
key = evaluate_ast(dictcomp.key, current_state, static_tools, custom_tools, authorized_imports)
1079-
value = evaluate_ast(dictcomp.value, current_state, static_tools, custom_tools, authorized_imports)
1080-
return [(key, value)]
1081-
10821097
return dict(
1083-
evaluate_nested_comp_helper(
1084-
dictcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports
1098+
_evaluate_comprehensions(
1099+
dictcomp.generators,
1100+
lambda comp_state: (
1101+
evaluate_ast(dictcomp.key, comp_state, static_tools, custom_tools, authorized_imports),
1102+
evaluate_ast(dictcomp.value, comp_state, static_tools, custom_tools, authorized_imports),
1103+
),
1104+
state,
1105+
static_tools,
1106+
custom_tools,
1107+
authorized_imports,
10851108
)
10861109
)
10871110

0 commit comments

Comments
 (0)