Skip to content

Commit 5c68747

Browse files
committed
Linting
1 parent d1f825e commit 5c68747

File tree

2 files changed

+26
-21
lines changed

2 files changed

+26
-21
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class AsyncMongoDBSaver(BaseCheckpointSaver):
5555
>>> from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
5656
>>> from langgraph.graph import StateGraph
5757
58-
>>> async def main():
58+
>>> async def main() -> None:
5959
>>> builder = StateGraph(int)
6060
>>> builder.add_node("add_one", lambda x: x + 1)
6161
>>> builder.set_entry_point("add_one")

libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_highlevel_graph.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
import operator
1616
import os
1717
import time
18-
from collections.abc import Generator
19-
from typing import Annotated, TypedDict
18+
from collections.abc import AsyncGenerator, Generator
19+
from typing import Annotated
2020

2121
import pytest
22+
from langchain_core.runnables import RunnableConfig
23+
from typing_extensions import TypedDict
2224

2325
from langgraph.checkpoint.base import BaseCheckpointSaver
2426
from langgraph.checkpoint.memory import InMemorySaver
@@ -55,21 +57,21 @@ class JokeState(JokeInput, JokeOutput): ...
5557

5658
def fanout_to_subgraph() -> StateGraph:
5759
# Subgraph nodes create a joke.
58-
def edit(state: JokeInput):
60+
def edit(state: JokeInput) -> dict[str, str]:
5961
return {"subject": f"{state["subject"]}, and cats"}
6062

61-
def generate(state: JokeInput):
63+
def generate(state: JokeInput) -> dict[str, list[str]]:
6264
return {"jokes": [f"Joke about the year {state['subject']}"]}
6365

64-
def bump(state: JokeOutput):
66+
def bump(state: JokeOutput) -> dict[str, list[str]]:
6567
return {"jokes": [state["jokes"][0] + " and the year before"]}
6668

67-
def bump_loop(state: JokeOutput):
69+
def bump_loop(state: JokeOutput) -> JokeOutput:
6870
return (
6971
END if state["jokes"][0].endswith(" and the year before" * 10) else "bump"
7072
)
7173

72-
subgraph = StateGraph(JokeState, joke_subjects=JokeInput, output=JokeOutput)
74+
subgraph = StateGraph(JokeState)
7375
subgraph.add_node("edit", edit)
7476
subgraph.add_node("generate", generate)
7577
subgraph.add_node("bump", bump)
@@ -82,18 +84,18 @@ def bump_loop(state: JokeOutput):
8284
subgraphc = subgraph.compile()
8385

8486
# Parent graph maps the joke-generating subgraph.
85-
def fanout(state: OverallState):
87+
def fanout(state: OverallState) -> list:
8688
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
8789

8890
parentgraph = StateGraph(OverallState)
89-
parentgraph.add_node("generate_joke", subgraphc)
91+
parentgraph.add_node("generate_joke", subgraphc) # type: ignore[arg-type]
9092
parentgraph.add_conditional_edges(START, fanout)
9193
parentgraph.add_edge("generate_joke", END)
9294
return parentgraph
9395

9496

9597
@pytest.fixture
96-
def joke_subjects():
98+
def joke_subjects() -> OverallState:
9799
years = [str(2025 - 10 * i) for i in range(N_SUBJECTS)]
98100
return {"subjects": years}
99101

@@ -119,29 +121,32 @@ def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]:
119121

120122

121123
@pytest.fixture(scope="function")
122-
async def checkpointer_mongodb_async() -> Generator[AsyncMongoDBSaver, None, None]:
124+
async def checkpointer_mongodb_async() -> AsyncGenerator[AsyncMongoDBSaver, None]:
123125
async with AsyncMongoDBSaver.from_conn_string(
124126
MONGODB_URI,
125127
db_name=DB_NAME,
126128
checkpoint_collection_name=CHECKPOINT_CLXN_NAME + "_async",
127129
writes_collection_name=WRITES_CLXN_NAME + "_async",
128130
) as checkpointer:
129-
checkpointer.checkpoint_collection.delete_many({})
130-
checkpointer.writes_collection.delete_many({})
131+
await checkpointer.checkpoint_collection.delete_many({})
132+
await checkpointer.writes_collection.delete_many({})
131133
yield checkpointer
132-
checkpointer.checkpoint_collection.drop()
133-
checkpointer.writes_collection.drop()
134+
await checkpointer.checkpoint_collection.drop()
135+
await checkpointer.writes_collection.drop()
134136

135137

136138
@pytest.fixture(autouse=True)
137-
def disable_langsmith():
139+
def disable_langsmith() -> None:
138140
"""Disable LangSmith tracing for all tests"""
139141
os.environ["LANGCHAIN_TRACING_V2"] = "false"
140142
os.environ["LANGCHAIN_API_KEY"] = ""
141143

142144

143145
async def test_fanout(
144-
joke_subjects, checkpointer_mongodb, checkpointer_mongodb_async, checkpointer_memory
146+
joke_subjects: OverallState,
147+
checkpointer_mongodb: MongoDBSaver,
148+
checkpointer_mongodb_async: AsyncMongoDBSaver,
149+
checkpointer_memory: InMemorySaver,
145150
) -> None:
146151
checkpointers = {
147152
"mongodb": checkpointer_mongodb,
@@ -154,12 +159,12 @@ async def test_fanout(
154159
assert isinstance(checkpointer, BaseCheckpointSaver)
155160
print(f"\n\nBegin test of {cname}")
156161
graphc = (fanout_to_subgraph()).compile(checkpointer=checkpointer)
157-
config = {"configurable": {"thread_id": cname}}
162+
config: RunnableConfig = {"configurable": {"thread_id": cname}}
158163
start = time.monotonic()
159164
if "async" in cname:
160-
out = [c async for c in graphc.astream(joke_subjects, config=config)]
165+
out = [c async for c in graphc.astream(joke_subjects, config=config)] # type: ignore[arg-type]
161166
else:
162-
out = [c for c in graphc.stream(joke_subjects, config=config)]
167+
out = [c for c in graphc.stream(joke_subjects, config=config)] # type: ignore[arg-type]
163168
assert len(out) == N_SUBJECTS
164169
assert isinstance(out[0], dict)
165170
assert out[0].keys() == {"generate_joke"}

0 commit comments

Comments
 (0)