Skip to content

Commit 4874dd9

Browse files
authored
add test (#19)
1 parent 0f10fe8 commit 4874dd9

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed

tests/test_swarm.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from typing import Optional
2+
3+
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
4+
from langchain_core.language_models.chat_models import BaseChatModel
5+
from langchain_core.messages import AIMessage, BaseMessage
6+
from langchain_core.outputs import ChatGeneration, ChatResult
7+
from langgraph.checkpoint.memory import MemorySaver
8+
from langgraph.prebuilt import create_react_agent
9+
10+
from langgraph_swarm import create_handoff_tool, create_swarm
11+
12+
13+
class FakeChatModel(BaseChatModel):
14+
idx: int = 0
15+
responses: list[BaseMessage]
16+
17+
@property
18+
def _llm_type(self) -> str:
19+
return "fake-tool-call-model"
20+
21+
def _generate(
22+
self,
23+
messages: list[BaseMessage],
24+
stop: Optional[list[str]] = None,
25+
run_manager: Optional[CallbackManagerForLLMRun] = None,
26+
**kwargs,
27+
) -> ChatResult:
28+
generation = ChatGeneration(message=self.responses[self.idx])
29+
self.idx += 1
30+
return ChatResult(generations=[generation])
31+
32+
def bind_tools(self, tools: list[any]) -> "FakeChatModel":
33+
return self
34+
35+
36+
def test_basic_swarm() -> None:
37+
# Create fake responses for the model
38+
recorded_messages = [
39+
AIMessage(
40+
content="",
41+
name="Alice",
42+
tool_calls=[
43+
{
44+
"name": "transfer_to_bob",
45+
"args": {},
46+
"id": "call_1LlFyjm6iIhDjdn7juWuPYr4",
47+
}
48+
],
49+
),
50+
AIMessage(
51+
content="Ahoy, matey! Bob the pirate be at yer service. What be ye needin' help with today on the high seas? Arrr!",
52+
name="Bob",
53+
),
54+
AIMessage(
55+
content="",
56+
name="Bob",
57+
tool_calls=[
58+
{
59+
"name": "transfer_to_alice",
60+
"args": {},
61+
"id": "call_T6pNmo2jTfZEK3a9avQ14f8Q",
62+
}
63+
],
64+
),
65+
AIMessage(
66+
content="",
67+
name="Alice",
68+
tool_calls=[
69+
{
70+
"name": "add",
71+
"args": {
72+
"a": 5,
73+
"b": 7,
74+
},
75+
"id": "call_4kLYO1amR2NfhAxfECkALCr1",
76+
}
77+
],
78+
),
79+
AIMessage(
80+
content="The sum of 5 and 7 is 12.",
81+
name="Alice",
82+
),
83+
]
84+
85+
model = FakeChatModel(responses=recorded_messages)
86+
87+
def add(a: int, b: int) -> int:
88+
"""Add two numbers"""
89+
return a + b
90+
91+
alice = create_react_agent(
92+
model,
93+
[add, create_handoff_tool(agent_name="Bob")],
94+
prompt="You are Alice, an addition expert.",
95+
name="Alice",
96+
)
97+
98+
bob = create_react_agent(
99+
model,
100+
[
101+
create_handoff_tool(
102+
agent_name="Alice", description="Transfer to Alice, she can help with math"
103+
)
104+
],
105+
prompt="You are Bob, you speak like a pirate.",
106+
name="Bob",
107+
)
108+
109+
checkpointer = MemorySaver()
110+
workflow = create_swarm([alice, bob], default_active_agent="Alice")
111+
app = workflow.compile(checkpointer=checkpointer)
112+
113+
config = {"configurable": {"thread_id": "1"}}
114+
turn_1 = app.invoke(
115+
{"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]},
116+
config,
117+
)
118+
119+
# Verify turn 1 results
120+
assert len(turn_1["messages"]) == 4
121+
assert turn_1["messages"][-2].content == "Successfully transferred to Bob"
122+
assert turn_1["messages"][-1].content == recorded_messages[1].content
123+
assert turn_1["active_agent"] == "Bob"
124+
125+
turn_2 = app.invoke(
126+
{"messages": [{"role": "user", "content": "what's 5 + 7?"}]},
127+
config,
128+
)
129+
130+
# Verify turn 2 results
131+
assert len(turn_2["messages"]) == 10
132+
assert turn_2["messages"][-4].content == "Successfully transferred to Alice"
133+
assert turn_2["messages"][-2].content == "12"
134+
assert turn_2["messages"][-1].content == recorded_messages[4].content
135+
assert turn_2["active_agent"] == "Alice"

0 commit comments

Comments
 (0)