Skip to content

Commit dd08e84

Browse files
INTPYTHON-679 Create test of high-level StateGraph API (#177)
1 parent 259af2a commit dd08e84

File tree

2 files changed

+179
-1
lines changed

2 files changed

+179
-1
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class AsyncMongoDBSaver(BaseCheckpointSaver):
5555
>>> from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
5656
>>> from langgraph.graph import StateGraph
5757
58-
>>> async def main():
58+
>>> async def main() -> None:
5959
>>> builder = StateGraph(int)
6060
>>> builder.add_node("add_one", lambda x: x + 1)
6161
>>> builder.set_entry_point("add_one")
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""
2+
Based on LangGraph's Benchmarking script,
3+
https://github.com/langchain-ai/langgraph/blob/main/libs/langgraph/bench/fanout_to_subgraph.py,
4+
this pattern of joke generation is used often in the examples.
5+
The fanout here is performed by the list comprehension of [class:~langgraph.types.Send] calls.
6+
The effect of this is a map (fanout) workflow where the graph invokes
7+
the same node multiple times in parallel.
8+
The node here is a subgraph.
9+
The subgraph is linear, with a conditional edge 'bump_loop' that repeatably calls
10+
the node 'bump' until a condition is met.
11+
This test can be used for benchmarking.
12+
It also demonstrates the high-level API of subgraphs, add_conditional_edges, and Send.
13+
"""
14+
15+
import operator
16+
import os
17+
import time
18+
from collections.abc import AsyncGenerator, Generator
19+
from typing import Annotated
20+
21+
import pytest
22+
from langchain_core.runnables import RunnableConfig
23+
from typing_extensions import TypedDict
24+
25+
from langgraph.checkpoint.base import BaseCheckpointSaver
26+
from langgraph.checkpoint.memory import InMemorySaver
27+
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver, MongoDBSaver
28+
from langgraph.constants import START, Send
29+
from langgraph.graph import END, StateGraph
30+
31+
# --- Configuration ---
32+
MONGODB_URI = os.environ.get(
33+
"MONGODB_URI", "mongodb://localhost:27017?directConnection=true"
34+
)
35+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
36+
CHECKPOINT_CLXN_NAME = "fanout_checkpoints"
37+
WRITES_CLXN_NAME = "fanout_writes"
38+
39+
N_SUBJECTS = 10 # increase for benchmarking
40+
41+
42+
class OverallState(TypedDict):
43+
subjects: list[str]
44+
jokes: Annotated[list[str], operator.add]
45+
46+
47+
class JokeInput(TypedDict):
48+
subject: str
49+
50+
51+
class JokeOutput(TypedDict):
52+
jokes: list[str]
53+
54+
55+
class JokeState(JokeInput, JokeOutput): ...
56+
57+
58+
def fanout_to_subgraph() -> StateGraph:
59+
# Subgraph nodes create a joke.
60+
def edit(state: JokeOutput) -> JokeOutput:
61+
return {"jokes": [f"{state['jokes'][0]}... and cats!"]}
62+
63+
def generate(state: JokeInput) -> JokeOutput:
64+
return {"jokes": [f"Joke about the year {state['subject']}"]}
65+
66+
def bump(state: JokeOutput) -> dict[str, list[str]]:
67+
return {"jokes": [state["jokes"][0] + " and the year before"]}
68+
69+
def bump_loop(state: JokeOutput) -> JokeOutput:
70+
return (
71+
"edit" if state["jokes"][0].endswith(" and the year before" * 3) else "bump"
72+
)
73+
74+
subgraph = StateGraph(JokeState)
75+
subgraph.add_node("edit", edit)
76+
subgraph.add_node("generate", generate)
77+
subgraph.add_node("bump", bump)
78+
subgraph.set_entry_point("generate")
79+
subgraph.add_edge("generate", "bump")
80+
subgraph.add_node("bump_loop", bump_loop)
81+
subgraph.add_conditional_edges("bump", bump_loop)
82+
subgraph.set_finish_point("edit")
83+
subgraphc = subgraph.compile()
84+
85+
# Parent graph maps the joke-generating subgraph.
86+
def fanout(state: OverallState) -> list:
87+
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
88+
89+
parentgraph = StateGraph(OverallState)
90+
parentgraph.add_node("generate_joke", subgraphc) # type: ignore[arg-type]
91+
parentgraph.add_conditional_edges(START, fanout)
92+
parentgraph.add_edge("generate_joke", END)
93+
return parentgraph
94+
95+
96+
@pytest.fixture
97+
def joke_subjects() -> OverallState:
98+
years = [str(2025 - 10 * i) for i in range(N_SUBJECTS)]
99+
return {"subjects": years}
100+
101+
102+
@pytest.fixture(scope="function")
103+
def checkpointer_memory() -> Generator[InMemorySaver, None, None]:
104+
yield InMemorySaver()
105+
106+
107+
@pytest.fixture(scope="function")
108+
def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]:
109+
with MongoDBSaver.from_conn_string(
110+
MONGODB_URI,
111+
db_name=DB_NAME,
112+
checkpoint_collection_name=CHECKPOINT_CLXN_NAME,
113+
writes_collection_name=WRITES_CLXN_NAME,
114+
) as checkpointer:
115+
checkpointer.checkpoint_collection.delete_many({})
116+
checkpointer.writes_collection.delete_many({})
117+
yield checkpointer
118+
checkpointer.checkpoint_collection.drop()
119+
checkpointer.writes_collection.drop()
120+
121+
122+
@pytest.fixture(scope="function")
123+
async def checkpointer_mongodb_async() -> AsyncGenerator[AsyncMongoDBSaver, None]:
124+
async with AsyncMongoDBSaver.from_conn_string(
125+
MONGODB_URI,
126+
db_name=DB_NAME,
127+
checkpoint_collection_name=CHECKPOINT_CLXN_NAME + "_async",
128+
writes_collection_name=WRITES_CLXN_NAME + "_async",
129+
) as checkpointer:
130+
await checkpointer.checkpoint_collection.delete_many({})
131+
await checkpointer.writes_collection.delete_many({})
132+
yield checkpointer
133+
await checkpointer.checkpoint_collection.drop()
134+
await checkpointer.writes_collection.drop()
135+
136+
137+
@pytest.fixture(autouse=True)
138+
def disable_langsmith() -> None:
139+
"""Disable LangSmith tracing for all tests"""
140+
os.environ["LANGCHAIN_TRACING_V2"] = "false"
141+
os.environ["LANGCHAIN_API_KEY"] = ""
142+
143+
144+
async def test_fanout(
145+
joke_subjects: OverallState,
146+
checkpointer_mongodb: MongoDBSaver,
147+
checkpointer_mongodb_async: AsyncMongoDBSaver,
148+
checkpointer_memory: InMemorySaver,
149+
) -> None:
150+
checkpointers = {
151+
"mongodb": checkpointer_mongodb,
152+
"mongodb_async": checkpointer_mongodb_async,
153+
"in_memory": checkpointer_memory,
154+
"in_memory_async": checkpointer_memory,
155+
}
156+
157+
for cname, checkpointer in checkpointers.items():
158+
assert isinstance(checkpointer, BaseCheckpointSaver)
159+
print(f"\n\nBegin test of {cname}")
160+
graphc = (fanout_to_subgraph()).compile(checkpointer=checkpointer)
161+
config: RunnableConfig = {"configurable": {"thread_id": cname}}
162+
start = time.monotonic()
163+
if "async" in cname:
164+
out = [c async for c in graphc.astream(joke_subjects, config=config)] # type: ignore[arg-type]
165+
else:
166+
out = [c for c in graphc.stream(joke_subjects, config=config)] # type: ignore[arg-type]
167+
assert len(out) == N_SUBJECTS
168+
assert isinstance(out[0], dict)
169+
assert out[0].keys() == {"generate_joke"}
170+
assert set(out[0]["generate_joke"].keys()) == {"jokes"}
171+
assert all(
172+
res["generate_joke"]["jokes"][0].endswith(
173+
f'{" and the year before" * 3}... and cats!'
174+
)
175+
for res in out
176+
)
177+
end = time.monotonic()
178+
print(f"{cname}: {end - start:.4f} seconds")

0 commit comments

Comments
 (0)