1919from collections import defaultdict , deque
2020import inspect
2121import textwrap
22- from typing import Any , TYPE_CHECKING
22+ from typing import TYPE_CHECKING , Any
2323
2424from typing_extensions import TypeIs
2525
26- from crewai .flow .constants import OR_CONDITION , AND_CONDITION
26+ from crewai .flow .constants import AND_CONDITION , OR_CONDITION
2727from crewai .flow .flow_wrappers import (
2828 FlowCondition ,
2929 FlowConditions ,
3333from crewai .flow .types import FlowMethodCallable , FlowMethodName
3434from crewai .utilities .printer import Printer
3535
36+
3637if TYPE_CHECKING :
3738 from crewai .flow .flow import Flow
3839
3940_printer = Printer ()
4041
4142
4243def get_possible_return_constants (function : Any ) -> list [str ] | None :
44+ """Extract possible string return values from a function using AST parsing.
45+
46+ This function analyzes the source code of a router method to identify
47+ all possible string values it might return. It handles:
48+ - Direct string literals: return "value"
49+ - Variable assignments: x = "value"; return x
50+ - Dictionary lookups: d = {"k": "v"}; return d[key]
51+ - Conditional returns: return "a" if cond else "b"
52+ - State attributes: return self.state.attr (infers from class context)
53+
54+ Args:
55+ function: The function to analyze.
56+
57+ Returns:
58+ List of possible string return values, or None if analysis fails.
59+ """
4360 try :
4461 source = inspect .getsource (function )
4562 except OSError :
@@ -82,6 +99,7 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
8299 return_values : set [str ] = set ()
83100 dict_definitions : dict [str , list [str ]] = {}
84101 variable_values : dict [str , list [str ]] = {}
102+ state_attribute_values : dict [str , list [str ]] = {}
85103
86104 def extract_string_constants (node : ast .expr ) -> list [str ]:
87105 """Recursively extract all string constants from an AST node."""
@@ -91,6 +109,17 @@ def extract_string_constants(node: ast.expr) -> list[str]:
91109 elif isinstance (node , ast .IfExp ):
92110 strings .extend (extract_string_constants (node .body ))
93111 strings .extend (extract_string_constants (node .orelse ))
112+ elif isinstance (node , ast .Call ):
113+ if (
114+ isinstance (node .func , ast .Attribute )
115+ and node .func .attr == "get"
116+ and len (node .args ) >= 2
117+ ):
118+ default_arg = node .args [1 ]
119+ if isinstance (default_arg , ast .Constant ) and isinstance (
120+ default_arg .value , str
121+ ):
122+ strings .append (default_arg .value )
94123 return strings
95124
96125 class VariableAssignmentVisitor (ast .NodeVisitor ):
@@ -124,6 +153,22 @@ def visit_Assign(self, node: ast.Assign) -> None:
124153
125154 self .generic_visit (node )
126155
156+ def get_attribute_chain (node : ast .expr ) -> str | None :
157+ """Extract the full attribute chain from an AST node.
158+
159+ Examples:
160+ self.state.run_type -> "self.state.run_type"
161+ x.y.z -> "x.y.z"
162+ simple_var -> "simple_var"
163+ """
164+ if isinstance (node , ast .Name ):
165+ return node .id
166+ if isinstance (node , ast .Attribute ):
167+ base = get_attribute_chain (node .value )
168+ if base :
169+ return f"{ base } .{ node .attr } "
170+ return None
171+
127172 class ReturnVisitor (ast .NodeVisitor ):
128173 def visit_Return (self , node : ast .Return ) -> None :
129174 if (
@@ -139,21 +184,94 @@ def visit_Return(self, node: ast.Return) -> None:
139184 for v in dict_definitions [var_name_dict ]:
140185 return_values .add (v )
141186 elif node .value :
142- var_name_ret : str | None = None
143- if isinstance (node .value , ast .Name ):
144- var_name_ret = node .value .id
145- elif isinstance (node .value , ast .Attribute ):
146- var_name_ret = f"{ node .value .value .id if isinstance (node .value .value , ast .Name ) else '_' } .{ node .value .attr } "
187+ var_name_ret = get_attribute_chain (node .value )
147188
148189 if var_name_ret and var_name_ret in variable_values :
149190 for v in variable_values [var_name_ret ]:
150191 return_values .add (v )
192+ elif var_name_ret and var_name_ret in state_attribute_values :
193+ for v in state_attribute_values [var_name_ret ]:
194+ return_values .add (v )
151195
152196 self .generic_visit (node )
153197
154198 def visit_If (self , node : ast .If ) -> None :
155199 self .generic_visit (node )
156200
201+ # Try to get the class context to infer state attribute values
202+ try :
203+ if hasattr (function , "__self__" ):
204+ # Method is bound, get the class
205+ class_obj = function .__self__ .__class__
206+ elif hasattr (function , "__qualname__" ) and "." in function .__qualname__ :
207+ # Method is unbound but we can try to get class from module
208+ class_name = function .__qualname__ .rsplit ("." , 1 )[0 ]
209+ if hasattr (function , "__globals__" ):
210+ class_obj = function .__globals__ .get (class_name )
211+ else :
212+ class_obj = None
213+ else :
214+ class_obj = None
215+
216+ if class_obj is not None :
217+ try :
218+ class_source = inspect .getsource (class_obj )
219+ class_source = textwrap .dedent (class_source )
220+ class_ast = ast .parse (class_source )
221+
222+ # Look for comparisons and assignments involving state attributes
223+ class StateAttributeVisitor (ast .NodeVisitor ):
224+ def visit_Compare (self , node : ast .Compare ) -> None :
225+ """Find comparisons like: self.state.attr == "value" """
226+ left_attr = get_attribute_chain (node .left )
227+
228+ if left_attr :
229+ for comparator in node .comparators :
230+ if isinstance (comparator , ast .Constant ) and isinstance (
231+ comparator .value , str
232+ ):
233+ if left_attr not in state_attribute_values :
234+ state_attribute_values [left_attr ] = []
235+ if (
236+ comparator .value
237+ not in state_attribute_values [left_attr ]
238+ ):
239+ state_attribute_values [left_attr ].append (
240+ comparator .value
241+ )
242+
243+ # Also check right side
244+ for comparator in node .comparators :
245+ right_attr = get_attribute_chain (comparator )
246+ if (
247+ right_attr
248+ and isinstance (node .left , ast .Constant )
249+ and isinstance (node .left .value , str )
250+ ):
251+ if right_attr not in state_attribute_values :
252+ state_attribute_values [right_attr ] = []
253+ if (
254+ node .left .value
255+ not in state_attribute_values [right_attr ]
256+ ):
257+ state_attribute_values [right_attr ].append (
258+ node .left .value
259+ )
260+
261+ self .generic_visit (node )
262+
263+ StateAttributeVisitor ().visit (class_ast )
264+ except Exception as e :
265+ _printer .print (
266+ f"Could not analyze class context for { function .__name__ } : { e } " ,
267+ color = "yellow" ,
268+ )
269+ except Exception as e :
270+ _printer .print (
271+ f"Could not introspect class for { function .__name__ } : { e } " ,
272+ color = "yellow" ,
273+ )
274+
157275 VariableAssignmentVisitor ().visit (code_ast )
158276 ReturnVisitor ().visit (code_ast )
159277
0 commit comments