Skip to content

Commit dcbe97b

Browse files
committed
Initial test implementation.
1 parent 7604eda commit dcbe97b

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
"""
2+
Based on LangGraph's Benchmarking script,
3+
https://github.com/langchain-ai/langgraph/blob/main/libs/langgraph/bench/fanout_to_subgraph.py,
4+
this pattern of joke generation is used often in the examples.
5+
The fanout here is performed by the list comprehension of [class:~langgraph.types.Send] calls.
6+
The effect of this is a map (fanout) workflow where the graph invokes
7+
the same node multiple times in parallel.
8+
The node here is a subgraph.
9+
The subgraph is linear, with a conditional edge 'bump_loop' that repeatably calls
10+
the node 'bump' until a condition is met.
11+
This test can be used for benchmarking.
12+
It also demonstrates the high-level API of subgraphs, add_conditional_edges, and Send.
13+
"""
14+
15+
import operator
16+
import os
17+
import time
18+
from collections.abc import Generator
19+
from typing import Annotated, TypedDict
20+
21+
import langchain_core
22+
import pytest
23+
24+
from langgraph.checkpoint.base import BaseCheckpointSaver
25+
from langgraph.checkpoint.memory import InMemorySaver
26+
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver
27+
from langgraph.constants import START, Send
28+
from langgraph.graph import END, StateGraph
29+
30+
# --- Configuration ---
31+
MONGODB_URI = os.environ.get(
32+
"MONGODB_URI", "mongodb://localhost:27017?directConnection=true"
33+
)
34+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
35+
CHECKPOINT_CLXN_NAME = "fanout_checkpoints"
36+
WRITES_CLXN_NAME = "fanout_writes"
37+
38+
N_SUBJECTS = 10 # increase for benchmarking
39+
40+
41+
class OverallState(TypedDict):
42+
subjects: list[str]
43+
jokes: Annotated[list[str], operator.add]
44+
45+
46+
class JokeInput(TypedDict):
47+
subject: str
48+
49+
50+
class JokeOutput(TypedDict):
51+
jokes: list[str]
52+
53+
54+
class JokeState(JokeInput, JokeOutput): ...
55+
56+
57+
@pytest.fixture
58+
def joke_subjects():
59+
years = [str(2025 - 10 * i) for i in range(N_SUBJECTS)]
60+
return {"subjects": years}
61+
62+
63+
@pytest.fixture(scope="function")
64+
def checkpointer_memory() -> Generator[InMemorySaver, None, None]:
65+
yield InMemorySaver()
66+
67+
68+
@pytest.fixture(scope="function")
69+
def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]:
70+
with MongoDBSaver.from_conn_string(
71+
MONGODB_URI,
72+
db_name=DB_NAME,
73+
checkpoint_collection_name=CHECKPOINT_CLXN_NAME,
74+
writes_collection_name=WRITES_CLXN_NAME,
75+
) as checkpointer:
76+
checkpointer.checkpoint_collection.delete_many({})
77+
checkpointer.writes_collection.delete_many({})
78+
yield checkpointer
79+
checkpointer.checkpoint_collection.drop()
80+
checkpointer.writes_collection.drop()
81+
82+
83+
@pytest.fixture(scope="function")
84+
async def checkpointer_mongodb_async() -> Generator[AsyncMongoDBSaver, None, None]:
85+
async with AsyncMongoDBSaver.from_conn_string(
86+
MONGODB_URI,
87+
db_name=DB_NAME,
88+
checkpoint_collection_name=CHECKPOINT_CLXN_NAME + "_async",
89+
writes_collection_name=WRITES_CLXN_NAME + "_async",
90+
) as checkpointer:
91+
checkpointer.checkpoint_collection.delete_many({})
92+
checkpointer.writes_collection.delete_many({})
93+
yield checkpointer
94+
checkpointer.checkpoint_collection.drop()
95+
checkpointer.writes_collection.drop()
96+
97+
98+
@pytest.fixture(autouse=True)
99+
def disable_langsmith():
100+
"""Disable LangSmith tracing for all tests"""
101+
os.environ["LANGCHAIN_TRACING_V2"] = "false"
102+
os.environ["LANGCHAIN_API_KEY"] = ""
103+
104+
105+
def test_sync(
106+
joke_subjects,
107+
checkpointer_mongodb,
108+
checkpointer_memory,
109+
) -> None:
110+
checkpointers = {
111+
"mongodb": checkpointer_mongodb,
112+
"in_memory": checkpointer_memory,
113+
}
114+
115+
def fanout_to_subgraph() -> StateGraph:
116+
# Subgraph nodes create a joke
117+
def edit(state: JokeInput):
118+
return {"subject": f"{state["subject"]}, and cats"}
119+
120+
def generate(state: JokeInput):
121+
return {"jokes": [f"Joke about the year {state['subject']}"]}
122+
123+
def bump(state: JokeOutput):
124+
return {"jokes": [state["jokes"][0] + " and another"]}
125+
126+
def bump_loop(state: JokeOutput):
127+
return END if state["jokes"][0].endswith(" and another" * 10) else "bump"
128+
129+
subgraph = StateGraph(JokeState, joke_subjects=JokeInput, output=JokeOutput)
130+
subgraph.add_node("edit", edit)
131+
subgraph.add_node("generate", generate)
132+
subgraph.add_node("bump", bump)
133+
subgraph.set_entry_point("edit")
134+
subgraph.add_edge("edit", "generate")
135+
subgraph.add_edge("generate", "bump")
136+
subgraph.add_node("bump_loop", bump_loop)
137+
subgraph.add_conditional_edges("bump", bump_loop)
138+
subgraph.set_finish_point("generate")
139+
subgraphc = subgraph.compile()
140+
141+
# parent graph maps the joke-generating subgraph
142+
def fanout(state: OverallState):
143+
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
144+
145+
parentgraph = StateGraph(OverallState)
146+
parentgraph.add_node("generate_joke", subgraphc)
147+
parentgraph.add_conditional_edges(START, fanout)
148+
parentgraph.add_edge("generate_joke", END)
149+
return parentgraph
150+
151+
print("\n\nBegin test_sync")
152+
for cname, checkpointer in checkpointers.items():
153+
assert isinstance(checkpointer, BaseCheckpointSaver)
154+
155+
graphc = fanout_to_subgraph().compile(checkpointer=checkpointer)
156+
assert isinstance(graphc.get_graph(), langchain_core.runnables.graph.Graph)
157+
config = {"configurable": {"thread_id": cname}}
158+
start = time.monotonic()
159+
out = [c for c in graphc.stream(joke_subjects, config=config)]
160+
assert len(out) == N_SUBJECTS
161+
assert isinstance(out[0], dict)
162+
assert out[0].keys() == {"generate_joke"}
163+
assert set(out[0]["generate_joke"].keys()) == {"jokes"}
164+
end = time.monotonic()
165+
print(f"{cname}: {end - start:.4f} seconds")
166+
167+
168+
async def test_async(
169+
joke_subjects, checkpointer_mongodb_async, checkpointer_memory
170+
) -> None:
171+
checkpointers = {
172+
"mongodb_async": checkpointer_mongodb_async,
173+
"in_memory_async": checkpointer_memory,
174+
}
175+
176+
async def fanout_to_subgraph() -> StateGraph:
177+
# Subgraph nodes create a joke
178+
async def edit(state: JokeInput):
179+
subject = state["subject"]
180+
return {"subject": f"{subject}, and cats"}
181+
182+
async def generate(state: JokeInput):
183+
return {"jokes": [f"Joke about the year {state['subject']}"]}
184+
185+
async def bump(state: JokeOutput):
186+
return {"jokes": [state["jokes"][0] + " and another"]}
187+
188+
async def bump_loop(state: JokeOutput):
189+
return END if state["jokes"][0].endswith(" and another" * 10) else "bump"
190+
191+
subgraph = StateGraph(JokeState, joke_subjects=JokeInput, output=JokeOutput)
192+
subgraph.add_node("edit", edit)
193+
subgraph.add_node("generate", generate)
194+
subgraph.add_node("bump", bump)
195+
subgraph.set_entry_point("edit")
196+
subgraph.add_edge("edit", "generate")
197+
subgraph.add_edge("generate", "bump")
198+
subgraph.add_conditional_edges("bump", bump_loop)
199+
subgraph.set_finish_point("generate")
200+
subgraphc = subgraph.compile()
201+
202+
# parent graph maps the joke-generating subgraph
203+
async def fanout(state: OverallState):
204+
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
205+
206+
parentgraph = StateGraph(OverallState)
207+
parentgraph.add_node("generate_joke", subgraphc)
208+
parentgraph.add_conditional_edges(START, fanout)
209+
parentgraph.add_edge("generate_joke", END)
210+
return parentgraph
211+
212+
print("\n\nBegin test_async")
213+
for cname, checkpointer in checkpointers.items():
214+
assert isinstance(checkpointer, BaseCheckpointSaver)
215+
216+
graphc = (await fanout_to_subgraph()).compile(checkpointer=checkpointer)
217+
config = {"configurable": {"thread_id": cname}}
218+
start = time.monotonic()
219+
out = [c async for c in graphc.astream(joke_subjects, config=config)]
220+
assert len(out) == N_SUBJECTS
221+
assert isinstance(out[0], dict)
222+
assert out[0].keys() == {"generate_joke"}
223+
assert set(out[0]["generate_joke"].keys()) == {"jokes"}
224+
end = time.monotonic()
225+
print(f"{cname}: {end - start:.4f} seconds")

0 commit comments

Comments
 (0)