Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 62 additions & 39 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,31 +1006,47 @@ def evaluate_for(
return result


def evaluate_nested_comp_helper(
generators: list[ast.comprehension],
base_case_evaluator: Callable[[dict[str, Any]], list],
def _evaluate_comprehensions(
comprehensions: list[ast.comprehension],
evaluate_element: Callable[[dict[str, Any]], Any],
state: dict[str, Any],
static_tools: dict[str, Callable],
custom_tools: dict[str, Callable],
authorized_imports: list[str],
) -> list:
def inner_evaluate(index: int, current_state: dict[str, Any]) -> list:
if index >= len(generators):
return base_case_evaluator(current_state)
generator = generators[index]
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools, authorized_imports)
results = []
for value in iter_value:
new_state = current_state.copy()
set_value(generator.target, value, new_state, static_tools, custom_tools, authorized_imports)
if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in generator.ifs
):
results.extend(inner_evaluate(index + 1, new_state))
return results
) -> Generator[Any, None, None]:
"""
Recursively evaluate nested comprehensions and yields elements.

return inner_evaluate(0, state)
Args:
comprehensions (`list[ast.comprehension]`): Comprehensions to evaluate.
evaluate_element (`Callable`): Function that evaluates the final element when comprehensions are exhausted.
state (`dict[str, Any]`): Current evaluation state.
static_tools (`dict[str, Callable]`): Static tools.
custom_tools (`dict[str, Callable]`): Custom tools.
authorized_imports (`list[str]`): Authorized imports.

Yields:
`Any`: Individual elements produced by the comprehension
"""
# Base case: no more comprehensions
if not comprehensions:
yield evaluate_element(state)
return
# Evaluate first comprehension
comprehension = comprehensions[0]
iter_value = evaluate_ast(comprehension.iter, state, static_tools, custom_tools, authorized_imports)
for value in iter_value:
new_state = state.copy()
set_value(comprehension.target, value, new_state, static_tools, custom_tools, authorized_imports)
# Check all filter conditions
if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in comprehension.ifs
):
# Recurse with remaining comprehensions
yield from _evaluate_comprehensions(
comprehensions[1:], evaluate_element, new_state, static_tools, custom_tools, authorized_imports
)


def evaluate_listcomp(
Expand All @@ -1040,12 +1056,15 @@ def evaluate_listcomp(
custom_tools: dict[str, Callable],
authorized_imports: list[str],
) -> list[Any]:
def base_case(current_state):
element = evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools, authorized_imports)
return [element]

return evaluate_nested_comp_helper(
listcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports
return list(
_evaluate_comprehensions(
listcomp.generators,
lambda comp_state: evaluate_ast(listcomp.elt, comp_state, static_tools, custom_tools, authorized_imports),
state,
static_tools,
custom_tools,
authorized_imports,
)
)


Expand All @@ -1056,13 +1075,14 @@ def evaluate_setcomp(
custom_tools: dict[str, Callable],
authorized_imports: list[str],
) -> set[Any]:
def base_case(current_state):
element = evaluate_ast(setcomp.elt, current_state, static_tools, custom_tools, authorized_imports)
return [element]

return set(
evaluate_nested_comp_helper(
setcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports
_evaluate_comprehensions(
setcomp.generators,
lambda comp_state: evaluate_ast(setcomp.elt, comp_state, static_tools, custom_tools, authorized_imports),
state,
static_tools,
custom_tools,
authorized_imports,
)
)

Expand All @@ -1074,14 +1094,17 @@ def evaluate_dictcomp(
custom_tools: dict[str, Callable],
authorized_imports: list[str],
) -> dict[Any, Any]:
def base_case(current_state):
key = evaluate_ast(dictcomp.key, current_state, static_tools, custom_tools, authorized_imports)
value = evaluate_ast(dictcomp.value, current_state, static_tools, custom_tools, authorized_imports)
return [(key, value)]

return dict(
evaluate_nested_comp_helper(
dictcomp.generators, base_case, state, static_tools, custom_tools, authorized_imports
_evaluate_comprehensions(
dictcomp.generators,
lambda comp_state: (
evaluate_ast(dictcomp.key, comp_state, static_tools, custom_tools, authorized_imports),
evaluate_ast(dictcomp.value, comp_state, static_tools, custom_tools, authorized_imports),
),
state,
static_tools,
custom_tools,
authorized_imports,
)
)

Expand Down