diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index f8ad68a3a..49c37beab 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -1006,31 +1006,47 @@ def evaluate_for( return result -def evaluate_nested_comp_helper( - generators: list[ast.comprehension], - base_case_evaluator: Callable[[dict[str, Any]], list], +def _evaluate_comprehensions( + comprehensions: list[ast.comprehension], + evaluate_element: Callable[[dict[str, Any]], Any], state: dict[str, Any], static_tools: dict[str, Callable], custom_tools: dict[str, Callable], authorized_imports: list[str], -) -> list: - def inner_evaluate(index: int, current_state: dict[str, Any]) -> list: - if index >= len(generators): - return base_case_evaluator(current_state) - generator = generators[index] - iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools, authorized_imports) - results = [] - for value in iter_value: - new_state = current_state.copy() - set_value(generator.target, value, new_state, static_tools, custom_tools, authorized_imports) - if all( - evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports) - for if_clause in generator.ifs - ): - results.extend(inner_evaluate(index + 1, new_state)) - return results +) -> Generator[Any, None, None]: + """ + Recursively evaluate nested comprehensions and yields elements. - return inner_evaluate(0, state) + Args: + comprehensions (`list[ast.comprehension]`): Comprehensions to evaluate. + evaluate_element (`Callable`): Function that evaluates the final element when comprehensions are exhausted. + state (`dict[str, Any]`): Current evaluation state. + static_tools (`dict[str, Callable]`): Static tools. + custom_tools (`dict[str, Callable]`): Custom tools. + authorized_imports (`list[str]`): Authorized imports. + + Yields: + `Any`: Individual elements produced by the comprehension + """ + # Base case: no more comprehensions + if not comprehensions: + yield evaluate_element(state) + return + # Evaluate first comprehension + comprehension = comprehensions[0] + iter_value = evaluate_ast(comprehension.iter, state, static_tools, custom_tools, authorized_imports) + for value in iter_value: + new_state = state.copy() + set_value(comprehension.target, value, new_state, static_tools, custom_tools, authorized_imports) + # Check all filter conditions + if all( + evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports) + for if_clause in comprehension.ifs + ): + # Recurse with remaining comprehensions + yield from _evaluate_comprehensions( + comprehensions[1:], evaluate_element, new_state, static_tools, custom_tools, authorized_imports + ) def evaluate_listcomp( @@ -1040,12 +1056,15 @@ def evaluate_listcomp( custom_tools: dict[str, Callable], authorized_imports: list[str], ) -> list[Any]: - def base_case(current_state): - element = evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools, authorized_imports) - return [element] - - return evaluate_nested_comp_helper( - listcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports + return list( + _evaluate_comprehensions( + listcomp.generators, + lambda comp_state: evaluate_ast(listcomp.elt, comp_state, static_tools, custom_tools, authorized_imports), + state, + static_tools, + custom_tools, + authorized_imports, + ) ) @@ -1056,13 +1075,14 @@ def evaluate_setcomp( custom_tools: dict[str, Callable], authorized_imports: list[str], ) -> set[Any]: - def base_case(current_state): - element = evaluate_ast(setcomp.elt, current_state, static_tools, custom_tools, authorized_imports) - return [element] - return set( - evaluate_nested_comp_helper( - setcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports + _evaluate_comprehensions( + setcomp.generators, + lambda comp_state: evaluate_ast(setcomp.elt, comp_state, static_tools, custom_tools, authorized_imports), + state, + static_tools, + custom_tools, + authorized_imports, ) ) @@ -1074,14 +1094,17 @@ def evaluate_dictcomp( custom_tools: dict[str, Callable], authorized_imports: list[str], ) -> dict[Any, Any]: - def base_case(current_state): - key = evaluate_ast(dictcomp.key, current_state, static_tools, custom_tools, authorized_imports) - value = evaluate_ast(dictcomp.value, current_state, static_tools, custom_tools, authorized_imports) - return [(key, value)] - return dict( - evaluate_nested_comp_helper( - dictcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports + _evaluate_comprehensions( + dictcomp.generators, + lambda comp_state: ( + evaluate_ast(dictcomp.key, comp_state, static_tools, custom_tools, authorized_imports), + evaluate_ast(dictcomp.value, comp_state, static_tools, custom_tools, authorized_imports), + ), + state, + static_tools, + custom_tools, + authorized_imports, ) )