Skip to content

Commit f6aed97

Browse files
feat: allow non-ast plot routes
1 parent 40a2d38 commit f6aed97

File tree

9 files changed

+1741
-419
lines changed

9 files changed

+1741
-419
lines changed

lib/crewai/src/crewai/flow/flow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ def __new__(
428428
possible_returns = get_possible_return_constants(attr_value)
429429
if possible_returns:
430430
router_paths[attr_name] = possible_returns
431+
else:
432+
router_paths[attr_name] = []
431433

432434
cls._start_methods = start_methods # type: ignore[attr-defined]
433435
cls._listeners = listeners # type: ignore[attr-defined]

lib/crewai/src/crewai/flow/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
R = TypeVar("R", covariant=True)
2222

2323
FlowMethodName = NewType("FlowMethodName", str)
24+
FlowRouteName = NewType("FlowRouteName", str)
2425
PendingListenerKey = NewType(
2526
"PendingListenerKey",
2627
Annotated[str, "nested flow conditions use 'listener_name:object_id'"],

lib/crewai/src/crewai/flow/utils.py

Lines changed: 125 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
from collections import defaultdict, deque
2020
import inspect
2121
import textwrap
22-
from typing import Any, TYPE_CHECKING
22+
from typing import TYPE_CHECKING, Any
2323

2424
from typing_extensions import TypeIs
2525

26-
from crewai.flow.constants import OR_CONDITION, AND_CONDITION
26+
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
2727
from crewai.flow.flow_wrappers import (
2828
FlowCondition,
2929
FlowConditions,
@@ -33,13 +33,30 @@
3333
from crewai.flow.types import FlowMethodCallable, FlowMethodName
3434
from crewai.utilities.printer import Printer
3535

36+
3637
if TYPE_CHECKING:
3738
from crewai.flow.flow import Flow
3839

3940
_printer = Printer()
4041

4142

4243
def get_possible_return_constants(function: Any) -> list[str] | None:
44+
"""Extract possible string return values from a function using AST parsing.
45+
46+
This function analyzes the source code of a router method to identify
47+
all possible string values it might return. It handles:
48+
- Direct string literals: return "value"
49+
- Variable assignments: x = "value"; return x
50+
- Dictionary lookups: d = {"k": "v"}; return d[key]
51+
- Conditional returns: return "a" if cond else "b"
52+
- State attributes: return self.state.attr (infers from class context)
53+
54+
Args:
55+
function: The function to analyze.
56+
57+
Returns:
58+
List of possible string return values, or None if analysis fails.
59+
"""
4360
try:
4461
source = inspect.getsource(function)
4562
except OSError:
@@ -82,6 +99,7 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
8299
return_values: set[str] = set()
83100
dict_definitions: dict[str, list[str]] = {}
84101
variable_values: dict[str, list[str]] = {}
102+
state_attribute_values: dict[str, list[str]] = {}
85103

86104
def extract_string_constants(node: ast.expr) -> list[str]:
87105
"""Recursively extract all string constants from an AST node."""
@@ -91,6 +109,17 @@ def extract_string_constants(node: ast.expr) -> list[str]:
91109
elif isinstance(node, ast.IfExp):
92110
strings.extend(extract_string_constants(node.body))
93111
strings.extend(extract_string_constants(node.orelse))
112+
elif isinstance(node, ast.Call):
113+
if (
114+
isinstance(node.func, ast.Attribute)
115+
and node.func.attr == "get"
116+
and len(node.args) >= 2
117+
):
118+
default_arg = node.args[1]
119+
if isinstance(default_arg, ast.Constant) and isinstance(
120+
default_arg.value, str
121+
):
122+
strings.append(default_arg.value)
94123
return strings
95124

96125
class VariableAssignmentVisitor(ast.NodeVisitor):
@@ -124,6 +153,22 @@ def visit_Assign(self, node: ast.Assign) -> None:
124153

125154
self.generic_visit(node)
126155

156+
def get_attribute_chain(node: ast.expr) -> str | None:
157+
"""Extract the full attribute chain from an AST node.
158+
159+
Examples:
160+
self.state.run_type -> "self.state.run_type"
161+
x.y.z -> "x.y.z"
162+
simple_var -> "simple_var"
163+
"""
164+
if isinstance(node, ast.Name):
165+
return node.id
166+
if isinstance(node, ast.Attribute):
167+
base = get_attribute_chain(node.value)
168+
if base:
169+
return f"{base}.{node.attr}"
170+
return None
171+
127172
class ReturnVisitor(ast.NodeVisitor):
128173
def visit_Return(self, node: ast.Return) -> None:
129174
if (
@@ -139,21 +184,94 @@ def visit_Return(self, node: ast.Return) -> None:
139184
for v in dict_definitions[var_name_dict]:
140185
return_values.add(v)
141186
elif node.value:
142-
var_name_ret: str | None = None
143-
if isinstance(node.value, ast.Name):
144-
var_name_ret = node.value.id
145-
elif isinstance(node.value, ast.Attribute):
146-
var_name_ret = f"{node.value.value.id if isinstance(node.value.value, ast.Name) else '_'}.{node.value.attr}"
187+
var_name_ret = get_attribute_chain(node.value)
147188

148189
if var_name_ret and var_name_ret in variable_values:
149190
for v in variable_values[var_name_ret]:
150191
return_values.add(v)
192+
elif var_name_ret and var_name_ret in state_attribute_values:
193+
for v in state_attribute_values[var_name_ret]:
194+
return_values.add(v)
151195

152196
self.generic_visit(node)
153197

154198
def visit_If(self, node: ast.If) -> None:
155199
self.generic_visit(node)
156200

201+
# Try to get the class context to infer state attribute values
202+
try:
203+
if hasattr(function, "__self__"):
204+
# Method is bound, get the class
205+
class_obj = function.__self__.__class__
206+
elif hasattr(function, "__qualname__") and "." in function.__qualname__:
207+
# Method is unbound but we can try to get class from module
208+
class_name = function.__qualname__.rsplit(".", 1)[0]
209+
if hasattr(function, "__globals__"):
210+
class_obj = function.__globals__.get(class_name)
211+
else:
212+
class_obj = None
213+
else:
214+
class_obj = None
215+
216+
if class_obj is not None:
217+
try:
218+
class_source = inspect.getsource(class_obj)
219+
class_source = textwrap.dedent(class_source)
220+
class_ast = ast.parse(class_source)
221+
222+
# Look for comparisons and assignments involving state attributes
223+
class StateAttributeVisitor(ast.NodeVisitor):
224+
def visit_Compare(self, node: ast.Compare) -> None:
225+
"""Find comparisons like: self.state.attr == "value" """
226+
left_attr = get_attribute_chain(node.left)
227+
228+
if left_attr:
229+
for comparator in node.comparators:
230+
if isinstance(comparator, ast.Constant) and isinstance(
231+
comparator.value, str
232+
):
233+
if left_attr not in state_attribute_values:
234+
state_attribute_values[left_attr] = []
235+
if (
236+
comparator.value
237+
not in state_attribute_values[left_attr]
238+
):
239+
state_attribute_values[left_attr].append(
240+
comparator.value
241+
)
242+
243+
# Also check right side
244+
for comparator in node.comparators:
245+
right_attr = get_attribute_chain(comparator)
246+
if (
247+
right_attr
248+
and isinstance(node.left, ast.Constant)
249+
and isinstance(node.left.value, str)
250+
):
251+
if right_attr not in state_attribute_values:
252+
state_attribute_values[right_attr] = []
253+
if (
254+
node.left.value
255+
not in state_attribute_values[right_attr]
256+
):
257+
state_attribute_values[right_attr].append(
258+
node.left.value
259+
)
260+
261+
self.generic_visit(node)
262+
263+
StateAttributeVisitor().visit(class_ast)
264+
except Exception as e:
265+
_printer.print(
266+
f"Could not analyze class context for {function.__name__}: {e}",
267+
color="yellow",
268+
)
269+
except Exception as e:
270+
_printer.print(
271+
f"Could not introspect class for {function.__name__}: {e}",
272+
color="yellow",
273+
)
274+
157275
VariableAssignmentVisitor().visit(code_ast)
158276
ReturnVisitor().visit(code_ast)
159277

0 commit comments

Comments
 (0)