|
1 | | -from typing import Type, TypeVar |
2 | | - |
3 | 1 | from langgraph.graph import START, MessagesState, StateGraph |
4 | 2 | from langgraph.graph.state import CompiledStateGraph |
| 3 | +from typing_extensions import Literal, Optional, Type, TypeVar, Union, get_args, get_origin |
5 | 4 |
|
6 | 5 | from langgraph_swarm.handoff import get_handoff_destinations |
7 | 6 |
|
8 | 7 |
|
9 | 8 | class SwarmState(MessagesState): |
10 | 9 | """State schema for the multi-agent swarm.""" |
11 | 10 |
|
12 | | - active_agent: str |
| 11 | + # NOTE: this state field is optional and is not expected to be provided by the user. |
| 12 | + # If a user does provide it, the graph will start from the specified active agent. |
| 13 | + # If active agent is typed as a `str`, we turn it into enum of all active agent names. |
| 14 | + active_agent: Optional[str] |
13 | 15 |
|
14 | 16 |
|
15 | 17 | StateSchema = TypeVar("StateSchema", bound=SwarmState) |
16 | 18 | StateSchemaType = Type[StateSchema] |
17 | 19 |
|
18 | 20 |
|
| 21 | +def _update_state_schema_agent_names( |
| 22 | + state_schema: StateSchemaType, agent_names: list[str] |
| 23 | +) -> StateSchemaType: |
| 24 | + """Update the state schema to use Literal with agent names for 'active_agent'.""" |
| 25 | + |
| 26 | + active_agent_annotation = state_schema.__annotations__["active_agent"] |
| 27 | + |
| 28 | + # Check if the annotation is str or Optional[str] |
| 29 | + is_str_type = active_agent_annotation is str |
| 30 | + is_optional_str = ( |
| 31 | + get_origin(active_agent_annotation) is Union and get_args(active_agent_annotation)[0] is str |
| 32 | + ) |
| 33 | + |
| 34 | + # We only update if the 'active_agent' is a str or Optional[str] |
| 35 | + if not (is_str_type or is_optional_str): |
| 36 | + return state_schema |
| 37 | + |
| 38 | + updated_schema = type( |
| 39 | + f"{state_schema.__name__}", |
| 40 | + (state_schema,), |
| 41 | + {"__annotations__": {**state_schema.__annotations__}}, |
| 42 | + ) |
| 43 | + |
| 44 | + # Create the Literal type with agent names |
| 45 | + literal_type = Literal.__getitem__(tuple(agent_names)) |
| 46 | + |
| 47 | + # If it was Optional[str], make it Optional[Literal[...]] |
| 48 | + if is_optional_str: |
| 49 | + updated_schema.__annotations__["active_agent"] = Optional[literal_type] |
| 50 | + else: |
| 51 | + updated_schema.__annotations__["active_agent"] = literal_type |
| 52 | + |
| 53 | + return updated_schema |
| 54 | + |
| 55 | + |
19 | 56 | def add_active_agent_router( |
20 | 57 | builder: StateGraph, |
21 | 58 | *, |
@@ -64,13 +101,16 @@ def create_swarm( |
64 | 101 | Returns: |
65 | 102 | A multi-agent swarm StateGraph. |
66 | 103 | """ |
67 | | - if "active_agent" not in state_schema.__annotations__: |
| 104 | + active_agent_annotation = state_schema.__annotations__.get("active_agent") |
| 105 | + if active_agent_annotation is None: |
68 | 106 | raise ValueError("Missing required key 'active_agent' in state_schema") |
69 | 107 |
|
| 108 | + agent_names = [agent.name for agent in agents] |
| 109 | + state_schema = _update_state_schema_agent_names(state_schema, agent_names) |
70 | 110 | builder = StateGraph(state_schema) |
71 | 111 | add_active_agent_router( |
72 | 112 | builder, |
73 | | - route_to=[agent.name for agent in agents], |
| 113 | + route_to=agent_names, |
74 | 114 | default_active_agent=default_active_agent, |
75 | 115 | ) |
76 | 116 | for agent in agents: |
|
0 commit comments