@@ -1006,31 +1006,47 @@ def evaluate_for(
10061006 return result
10071007
10081008
1009- def evaluate_nested_comp_helper (
1010- generators : list [ast .comprehension ],
1011- base_case_evaluator : Callable [[dict [str , Any ]], list ],
1009+ def _evaluate_comprehensions (
1010+ comprehensions : list [ast .comprehension ],
1011+ evaluate_element : Callable [[dict [str , Any ]], Any ],
10121012 state : dict [str , Any ],
10131013 static_tools : dict [str , Callable ],
10141014 custom_tools : dict [str , Callable ],
10151015 authorized_imports : list [str ],
1016- ) -> list :
1017- def inner_evaluate (index : int , current_state : dict [str , Any ]) -> list :
1018- if index >= len (generators ):
1019- return base_case_evaluator (current_state )
1020- generator = generators [index ]
1021- iter_value = evaluate_ast (generator .iter , current_state , static_tools , custom_tools , authorized_imports )
1022- results = []
1023- for value in iter_value :
1024- new_state = current_state .copy ()
1025- set_value (generator .target , value , new_state , static_tools , custom_tools , authorized_imports )
1026- if all (
1027- evaluate_ast (if_clause , new_state , static_tools , custom_tools , authorized_imports )
1028- for if_clause in generator .ifs
1029- ):
1030- results .extend (inner_evaluate (index + 1 , new_state ))
1031- return results
1016+ ) -> Generator [Any , None , None ]:
1017+ """
1018+ Recursively evaluate nested comprehensions and yields elements.
10321019
1033- return inner_evaluate (0 , state )
1020+ Args:
1021+ comprehensions (`list[ast.comprehension]`): Comprehensions to evaluate.
1022+ evaluate_element (`Callable`): Function that evaluates the final element when comprehensions are exhausted.
1023+ state (`dict[str, Any]`): Current evaluation state.
1024+ static_tools (`dict[str, Callable]`): Static tools.
1025+ custom_tools (`dict[str, Callable]`): Custom tools.
1026+ authorized_imports (`list[str]`): Authorized imports.
1027+
1028+ Yields:
1029+ `Any`: Individual elements produced by the comprehension
1030+ """
1031+ # Base case: no more comprehensions
1032+ if not comprehensions :
1033+ yield evaluate_element (state )
1034+ return
1035+ # Evaluate first comprehension
1036+ comprehension = comprehensions [0 ]
1037+ iter_value = evaluate_ast (comprehension .iter , state , static_tools , custom_tools , authorized_imports )
1038+ for value in iter_value :
1039+ new_state = state .copy ()
1040+ set_value (comprehension .target , value , new_state , static_tools , custom_tools , authorized_imports )
1041+ # Check all filter conditions
1042+ if all (
1043+ evaluate_ast (if_clause , new_state , static_tools , custom_tools , authorized_imports )
1044+ for if_clause in comprehension .ifs
1045+ ):
1046+ # Recurse with remaining comprehensions
1047+ yield from _evaluate_comprehensions (
1048+ comprehensions [1 :], evaluate_element , new_state , static_tools , custom_tools , authorized_imports
1049+ )
10341050
10351051
10361052def evaluate_listcomp (
@@ -1040,12 +1056,15 @@ def evaluate_listcomp(
10401056 custom_tools : dict [str , Callable ],
10411057 authorized_imports : list [str ],
10421058) -> list [Any ]:
1043- def base_case (current_state ):
1044- element = evaluate_ast (listcomp .elt , current_state , static_tools , custom_tools , authorized_imports )
1045- return [element ]
1046-
1047- return evaluate_nested_comp_helper (
1048- listcomp .generators , base_case , state , static_tools , custom_tools , authorized_imports
1059+ return list (
1060+ _evaluate_comprehensions (
1061+ listcomp .generators ,
1062+ lambda comp_state : evaluate_ast (listcomp .elt , comp_state , static_tools , custom_tools , authorized_imports ),
1063+ state ,
1064+ static_tools ,
1065+ custom_tools ,
1066+ authorized_imports ,
1067+ )
10491068 )
10501069
10511070
@@ -1056,13 +1075,14 @@ def evaluate_setcomp(
10561075 custom_tools : dict [str , Callable ],
10571076 authorized_imports : list [str ],
10581077) -> set [Any ]:
1059- def base_case (current_state ):
1060- element = evaluate_ast (setcomp .elt , current_state , static_tools , custom_tools , authorized_imports )
1061- return [element ]
1062-
10631078 return set (
1064- evaluate_nested_comp_helper (
1065- setcomp .generators , base_case , state , static_tools , custom_tools , authorized_imports
1079+ _evaluate_comprehensions (
1080+ setcomp .generators ,
1081+ lambda comp_state : evaluate_ast (setcomp .elt , comp_state , static_tools , custom_tools , authorized_imports ),
1082+ state ,
1083+ static_tools ,
1084+ custom_tools ,
1085+ authorized_imports ,
10661086 )
10671087 )
10681088
@@ -1074,14 +1094,17 @@ def evaluate_dictcomp(
10741094 custom_tools : dict [str , Callable ],
10751095 authorized_imports : list [str ],
10761096) -> dict [Any , Any ]:
1077- def base_case (current_state ):
1078- key = evaluate_ast (dictcomp .key , current_state , static_tools , custom_tools , authorized_imports )
1079- value = evaluate_ast (dictcomp .value , current_state , static_tools , custom_tools , authorized_imports )
1080- return [(key , value )]
1081-
10821097 return dict (
1083- evaluate_nested_comp_helper (
1084- dictcomp .generators , base_case , state , static_tools , custom_tools , authorized_imports
1098+ _evaluate_comprehensions (
1099+ dictcomp .generators ,
1100+ lambda comp_state : (
1101+ evaluate_ast (dictcomp .key , comp_state , static_tools , custom_tools , authorized_imports ),
1102+ evaluate_ast (dictcomp .value , comp_state , static_tools , custom_tools , authorized_imports ),
1103+ ),
1104+ state ,
1105+ static_tools ,
1106+ custom_tools ,
1107+ authorized_imports ,
10851108 )
10861109 )
10871110
0 commit comments