Skip to content

Commit ed13b00

Browse files
authored
update active_agent typing, use agent names enum as default type (#48)
1 parent 4c2b858 commit ed13b00

File tree

1 file changed

+45
-5
lines changed

1 file changed

+45
-5
lines changed

langgraph_swarm/swarm.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,58 @@
1-
from typing import Type, TypeVar
2-
31
from langgraph.graph import START, MessagesState, StateGraph
42
from langgraph.graph.state import CompiledStateGraph
3+
from typing_extensions import Literal, Optional, Type, TypeVar, Union, get_args, get_origin
54

65
from langgraph_swarm.handoff import get_handoff_destinations
76

87

98
class SwarmState(MessagesState):
109
"""State schema for the multi-agent swarm."""
1110

12-
active_agent: str
11+
# NOTE: this state field is optional and is not expected to be provided by the user.
12+
# If a user does provide it, the graph will start from the specified active agent.
13+
# If active agent is typed as a `str`, we turn it into enum of all active agent names.
14+
active_agent: Optional[str]
1315

1416

1517
StateSchema = TypeVar("StateSchema", bound=SwarmState)
1618
StateSchemaType = Type[StateSchema]
1719

1820

21+
def _update_state_schema_agent_names(
22+
state_schema: StateSchemaType, agent_names: list[str]
23+
) -> StateSchemaType:
24+
"""Update the state schema to use Literal with agent names for 'active_agent'."""
25+
26+
active_agent_annotation = state_schema.__annotations__["active_agent"]
27+
28+
# Check if the annotation is str or Optional[str]
29+
is_str_type = active_agent_annotation is str
30+
is_optional_str = (
31+
get_origin(active_agent_annotation) is Union and get_args(active_agent_annotation)[0] is str
32+
)
33+
34+
# We only update if the 'active_agent' is a str or Optional[str]
35+
if not (is_str_type or is_optional_str):
36+
return state_schema
37+
38+
updated_schema = type(
39+
f"{state_schema.__name__}",
40+
(state_schema,),
41+
{"__annotations__": {**state_schema.__annotations__}},
42+
)
43+
44+
# Create the Literal type with agent names
45+
literal_type = Literal.__getitem__(tuple(agent_names))
46+
47+
# If it was Optional[str], make it Optional[Literal[...]]
48+
if is_optional_str:
49+
updated_schema.__annotations__["active_agent"] = Optional[literal_type]
50+
else:
51+
updated_schema.__annotations__["active_agent"] = literal_type
52+
53+
return updated_schema
54+
55+
1956
def add_active_agent_router(
2057
builder: StateGraph,
2158
*,
@@ -64,13 +101,16 @@ def create_swarm(
64101
Returns:
65102
A multi-agent swarm StateGraph.
66103
"""
67-
if "active_agent" not in state_schema.__annotations__:
104+
active_agent_annotation = state_schema.__annotations__.get("active_agent")
105+
if active_agent_annotation is None:
68106
raise ValueError("Missing required key 'active_agent' in state_schema")
69107

108+
agent_names = [agent.name for agent in agents]
109+
state_schema = _update_state_schema_agent_names(state_schema, agent_names)
70110
builder = StateGraph(state_schema)
71111
add_active_agent_router(
72112
builder,
73-
route_to=[agent.name for agent in agents],
113+
route_to=agent_names,
74114
default_active_agent=default_active_agent,
75115
)
76116
for agent in agents:

0 commit comments

Comments
 (0)