|
18 | 18 | from collections.abc import Generator |
19 | 19 | from typing import Annotated, TypedDict |
20 | 20 |
|
21 | | -import langchain_core |
22 | 21 | import pytest |
23 | 22 |
|
24 | 23 | from langgraph.checkpoint.base import BaseCheckpointSaver |
@@ -63,10 +62,12 @@ def generate(state: JokeInput): |
63 | 62 | return {"jokes": [f"Joke about the year {state['subject']}"]} |
64 | 63 |
|
65 | 64 | def bump(state: JokeOutput): |
66 | | - return {"jokes": [state["jokes"][0] + " and another"]} |
| 65 | + return {"jokes": [state["jokes"][0] + " and the year before"]} |
67 | 66 |
|
68 | 67 | def bump_loop(state: JokeOutput): |
69 | | - return END if state["jokes"][0].endswith(" and another" * 10) else "bump" |
| 68 | + return ( |
| 69 | + END if state["jokes"][0].endswith(" and the year before" * 10) else "bump" |
| 70 | + ) |
70 | 71 |
|
71 | 72 | subgraph = StateGraph(JokeState, joke_subjects=JokeInput, output=JokeOutput) |
72 | 73 | subgraph.add_node("edit", edit) |
@@ -139,49 +140,26 @@ def disable_langsmith(): |
139 | 140 | os.environ["LANGCHAIN_API_KEY"] = "" |
140 | 141 |
|
141 | 142 |
|
142 | | -def test_sync( |
143 | | - joke_subjects, |
144 | | - checkpointer_mongodb, |
145 | | - checkpointer_memory, |
| 143 | +async def test_fanout( |
| 144 | + joke_subjects, checkpointer_mongodb, checkpointer_mongodb_async, checkpointer_memory |
146 | 145 | ) -> None: |
147 | 146 | checkpointers = { |
148 | 147 | "mongodb": checkpointer_mongodb, |
149 | | - "in_memory": checkpointer_memory, |
150 | | - } |
151 | | - |
152 | | - print("\n\nBegin test_sync") |
153 | | - for cname, checkpointer in checkpointers.items(): |
154 | | - assert isinstance(checkpointer, BaseCheckpointSaver) |
155 | | - |
156 | | - graphc = fanout_to_subgraph().compile(checkpointer=checkpointer) |
157 | | - assert isinstance(graphc.get_graph(), langchain_core.runnables.graph.Graph) |
158 | | - config = {"configurable": {"thread_id": cname}} |
159 | | - start = time.monotonic() |
160 | | - out = [c for c in graphc.stream(joke_subjects, config=config)] |
161 | | - assert len(out) == N_SUBJECTS |
162 | | - assert isinstance(out[0], dict) |
163 | | - assert out[0].keys() == {"generate_joke"} |
164 | | - assert set(out[0]["generate_joke"].keys()) == {"jokes"} |
165 | | - end = time.monotonic() |
166 | | - print(f"{cname}: {end - start:.4f} seconds") |
167 | | - |
168 | | - |
169 | | -async def test_async( |
170 | | - joke_subjects, checkpointer_mongodb_async, checkpointer_memory |
171 | | -) -> None: |
172 | | - checkpointers = { |
173 | 148 | "mongodb_async": checkpointer_mongodb_async, |
| 149 | + "in_memory": checkpointer_memory, |
174 | 150 | "in_memory_async": checkpointer_memory, |
175 | 151 | } |
176 | 152 |
|
177 | | - print("\n\nBegin test_async") |
178 | 153 | for cname, checkpointer in checkpointers.items(): |
179 | 154 | assert isinstance(checkpointer, BaseCheckpointSaver) |
180 | | - |
| 155 | + print(f"\n\nBegin test of {cname}") |
181 | 156 | graphc = (fanout_to_subgraph()).compile(checkpointer=checkpointer) |
182 | 157 | config = {"configurable": {"thread_id": cname}} |
183 | 158 | start = time.monotonic() |
184 | | - out = [c async for c in graphc.astream(joke_subjects, config=config)] |
| 159 | + if "async" in cname: |
| 160 | + out = [c async for c in graphc.astream(joke_subjects, config=config)] |
| 161 | + else: |
| 162 | + out = [c for c in graphc.stream(joke_subjects, config=config)] |
185 | 163 | assert len(out) == N_SUBJECTS |
186 | 164 | assert isinstance(out[0], dict) |
187 | 165 | assert out[0].keys() == {"generate_joke"} |
|
0 commit comments