Skip to content

Commit 64adb2b

Browse files
open-swe[bot]open-swesydney-runkleeyurtsev
authored
feat: Add context coercion for LangGraph runtime (langchain-ai#5736)
Fixes langchain-ai#5735 Implement context coercion functionality for LangGraph runtime to improve API usability. Key changes: - Added `_coerce_context` function in `pregel/main.py` - Supports coercion for: - Pydantic BaseModel - Dataclasses - TypedDict - Comprehensive test coverage added in `tests/test_runtime.py` - Handles edge cases like None context and missing fields The implementation allows users to pass dictionaries as context, which will be automatically converted to the expected schema type, making the API more flexible and user-friendly. --------- Co-authored-by: open-swe[bot] <[email protected]> Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent e87f0fb commit 64adb2b

File tree

2 files changed

+315
-2
lines changed

2 files changed

+315
-2
lines changed

libs/langgraph/langgraph/pregel/main.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,7 +2571,7 @@ def stream_writer(c: Any) -> None:
25712571
config[CONF][CONFIG_KEY_DURABILITY] = durability_
25722572

25732573
runtime = Runtime(
2574-
context=context,
2574+
context=_coerce_context(self.context_schema, context),
25752575
store=store,
25762576
stream_writer=stream_writer,
25772577
previous=None,
@@ -2866,7 +2866,7 @@ def stream_writer(c: Any) -> None:
28662866
config[CONF][CONFIG_KEY_DURABILITY] = durability_
28672867

28682868
runtime = Runtime(
2869-
context=context,
2869+
context=_coerce_context(self.context_schema, context),
28702870
store=store,
28712871
stream_writer=stream_writer,
28722872
previous=None,
@@ -3224,3 +3224,33 @@ def _output(
32243224
yield (ns, payload)
32253225
else:
32263226
yield payload
3227+
3228+
3229+
def _coerce_context(
3230+
context_schema: type[ContextT] | None, context: Any
3231+
) -> ContextT | None:
3232+
"""Coerce context input to the appropriate schema type.
3233+
3234+
If context is a dict and context_schema is a dataclass or pydantic model, we coerce.
3235+
Else, we return the context as-is.
3236+
3237+
Args:
3238+
context_schema: The schema type to coerce to (BaseModel, dataclass, or TypedDict)
3239+
context: The context value to coerce
3240+
3241+
Returns:
3242+
The coerced context value or None if context is None
3243+
"""
3244+
if context is None:
3245+
return None
3246+
3247+
if context_schema is None:
3248+
return context
3249+
3250+
schema_is_class = issubclass(context_schema, BaseModel) or is_dataclass(
3251+
context_schema
3252+
)
3253+
if isinstance(context, dict) and schema_is_class:
3254+
return context_schema(**context) # type: ignore[misc]
3255+
3256+
return cast(ContextT, context)

libs/langgraph/tests/test_runtime.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from typing import Any
33

4+
import pytest
5+
from pydantic import BaseModel, ValidationError
46
from typing_extensions import TypedDict
57

68
from langgraph.graph import END, START, StateGraph
@@ -106,3 +108,284 @@ def main_node(state: State, runtime: Runtime[Context]):
106108
context = Context(username="Alice")
107109
result = graph.invoke({}, context=context)
108110
assert result == {"subgraph": "Alice!", "main": "Alice!"}
111+
112+
113+
def test_context_coercion_dataclass() -> None:
114+
"""Test that dict context is coerced to dataclass."""
115+
116+
@dataclass
117+
class Context:
118+
api_key: str
119+
timeout: int = 30
120+
121+
class State(TypedDict):
122+
message: str
123+
124+
def node_with_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
125+
return {
126+
"message": f"api_key: {runtime.context.api_key}, timeout: {runtime.context.timeout}"
127+
}
128+
129+
graph = StateGraph(state_schema=State, context_schema=Context)
130+
graph.add_node("node", node_with_context)
131+
graph.add_edge(START, "node")
132+
graph.add_edge("node", END)
133+
compiled = graph.compile()
134+
135+
# Test dict coercion with all fields
136+
result = compiled.invoke(
137+
{"message": "test"}, context={"api_key": "sk_test", "timeout": 60}
138+
)
139+
assert result == {"message": "api_key: sk_test, timeout: 60"}
140+
141+
# Test dict coercion with default field
142+
result = compiled.invoke({"message": "test"}, context={"api_key": "sk_test2"})
143+
assert result == {"message": "api_key: sk_test2, timeout: 30"}
144+
145+
# Test with actual dataclass instance (should still work)
146+
result = compiled.invoke(
147+
{"message": "test"}, context=Context(api_key="sk_test3", timeout=90)
148+
)
149+
assert result == {"message": "api_key: sk_test3, timeout: 90"}
150+
151+
152+
def test_context_coercion_pydantic() -> None:
153+
"""Test that dict context is coerced to Pydantic model."""
154+
155+
class Context(BaseModel):
156+
api_key: str
157+
timeout: int = 30
158+
tags: list[str] = []
159+
160+
class State(TypedDict):
161+
message: str
162+
163+
def node_with_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
164+
return {
165+
"message": f"api_key: {runtime.context.api_key}, timeout: {runtime.context.timeout}, tags: {runtime.context.tags}"
166+
}
167+
168+
graph = StateGraph(state_schema=State, context_schema=Context)
169+
graph.add_node("node", node_with_context)
170+
graph.add_edge(START, "node")
171+
graph.add_edge("node", END)
172+
compiled = graph.compile()
173+
174+
# Test dict coercion with all fields
175+
result = compiled.invoke(
176+
{"message": "test"},
177+
context={"api_key": "sk_test", "timeout": 60, "tags": ["prod", "v2"]},
178+
)
179+
assert result == {"message": "api_key: sk_test, timeout: 60, tags: ['prod', 'v2']"}
180+
181+
# Test dict coercion with defaults
182+
result = compiled.invoke({"message": "test"}, context={"api_key": "sk_test2"})
183+
assert result == {"message": "api_key: sk_test2, timeout: 30, tags: []"}
184+
185+
# Test with actual Pydantic instance (should still work)
186+
result = compiled.invoke(
187+
{"message": "test"},
188+
context=Context(api_key="sk_test3", timeout=90, tags=["test"]),
189+
)
190+
assert result == {"message": "api_key: sk_test3, timeout: 90, tags: ['test']"}
191+
192+
193+
def test_context_coercion_typeddict() -> None:
194+
"""Test that dict context with TypedDict schema passes through as-is."""
195+
196+
class Context(TypedDict):
197+
api_key: str
198+
timeout: int
199+
200+
class State(TypedDict):
201+
message: str
202+
203+
def node_with_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
204+
# TypedDict context is just a dict at runtime
205+
return {
206+
"message": f"api_key: {runtime.context['api_key']}, timeout: {runtime.context['timeout']}"
207+
}
208+
209+
graph = StateGraph(state_schema=State, context_schema=Context)
210+
graph.add_node("node", node_with_context)
211+
graph.add_edge(START, "node")
212+
graph.add_edge("node", END)
213+
compiled = graph.compile()
214+
215+
# Test dict passes through for TypedDict
216+
result = compiled.invoke(
217+
{"message": "test"}, context={"api_key": "sk_test", "timeout": 60}
218+
)
219+
assert result == {"message": "api_key: sk_test, timeout: 60"}
220+
221+
222+
def test_context_coercion_none() -> None:
223+
"""Test that None context is handled properly."""
224+
225+
@dataclass
226+
class Context:
227+
api_key: str
228+
229+
class State(TypedDict):
230+
message: str
231+
232+
def node_without_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
233+
# Should be None when no context provided
234+
return {"message": f"context is None: {runtime.context is None}"}
235+
236+
graph = StateGraph(state_schema=State, context_schema=Context)
237+
graph.add_node("node", node_without_context)
238+
graph.add_edge(START, "node")
239+
graph.add_edge("node", END)
240+
compiled = graph.compile()
241+
242+
# Test with None context
243+
result = compiled.invoke({"message": "test"}, context=None)
244+
assert result == {"message": "context is None: True"}
245+
246+
# Test without context parameter (defaults to None)
247+
result = compiled.invoke({"message": "test"})
248+
assert result == {"message": "context is None: True"}
249+
250+
251+
def test_context_coercion_errors() -> None:
252+
"""Test error handling for invalid context."""
253+
254+
@dataclass
255+
class Context:
256+
api_key: str # Required field
257+
258+
class State(TypedDict):
259+
message: str
260+
261+
def node_with_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
262+
return {"message": "should not reach here"}
263+
264+
graph = StateGraph(state_schema=State, context_schema=Context)
265+
graph.add_node("node", node_with_context)
266+
graph.add_edge(START, "node")
267+
graph.add_edge("node", END)
268+
compiled = graph.compile()
269+
270+
# Test missing required field
271+
with pytest.raises(TypeError):
272+
compiled.invoke({"message": "test"}, context={"timeout": 60})
273+
274+
# Test invalid dict keys
275+
with pytest.raises(TypeError):
276+
compiled.invoke(
277+
{"message": "test"}, context={"api_key": "test", "invalid_field": "value"}
278+
)
279+
280+
281+
@pytest.mark.anyio
282+
async def test_context_coercion_async() -> None:
283+
"""Test context coercion with async methods."""
284+
285+
@dataclass
286+
class Context:
287+
api_key: str
288+
async_mode: bool = True
289+
290+
class State(TypedDict):
291+
message: str
292+
293+
async def async_node(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
294+
return {
295+
"message": f"async api_key: {runtime.context.api_key}, async_mode: {runtime.context.async_mode}"
296+
}
297+
298+
graph = StateGraph(state_schema=State, context_schema=Context)
299+
graph.add_node("node", async_node)
300+
graph.add_edge(START, "node")
301+
graph.add_edge("node", END)
302+
compiled = graph.compile()
303+
304+
# Test dict coercion with ainvoke
305+
result = await compiled.ainvoke(
306+
{"message": "test"}, context={"api_key": "sk_async", "async_mode": False}
307+
)
308+
assert result == {"message": "async api_key: sk_async, async_mode: False"}
309+
310+
# Test dict coercion with astream
311+
chunks = []
312+
async for chunk in compiled.astream(
313+
{"message": "test"}, context={"api_key": "sk_stream"}
314+
):
315+
chunks.append(chunk)
316+
317+
# Find the chunk with our node output
318+
node_output = None
319+
for chunk in chunks:
320+
if "node" in chunk:
321+
node_output = chunk["node"]
322+
break
323+
324+
assert node_output == {"message": "async api_key: sk_stream, async_mode: True"}
325+
326+
327+
def test_context_coercion_stream() -> None:
328+
"""Test context coercion with sync stream method."""
329+
330+
@dataclass
331+
class Context:
332+
api_key: str
333+
stream_mode: str = "default"
334+
335+
class State(TypedDict):
336+
message: str
337+
338+
def node_with_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
339+
return {
340+
"message": f"stream api_key: {runtime.context.api_key}, mode: {runtime.context.stream_mode}"
341+
}
342+
343+
graph = StateGraph(state_schema=State, context_schema=Context)
344+
graph.add_node("node", node_with_context)
345+
graph.add_edge(START, "node")
346+
graph.add_edge("node", END)
347+
compiled = graph.compile()
348+
349+
# Test dict coercion with stream
350+
chunks = []
351+
for chunk in compiled.stream(
352+
{"message": "test"}, context={"api_key": "sk_stream", "stream_mode": "fast"}
353+
):
354+
chunks.append(chunk)
355+
356+
# Find the chunk with our node output
357+
node_output = None
358+
for chunk in chunks:
359+
if "node" in chunk:
360+
node_output = chunk["node"]
361+
break
362+
363+
assert node_output == {"message": "stream api_key: sk_stream, mode: fast"}
364+
365+
366+
def test_context_coercion_pydantic_validation_errors() -> None:
367+
"""Test that Pydantic validation errors are raised."""
368+
369+
class Context(BaseModel):
370+
api_key: str
371+
timeout: int
372+
373+
class State(TypedDict):
374+
message: str
375+
376+
def node_with_context(state: State, runtime: Runtime[Context]) -> dict[str, Any]:
377+
return {
378+
"message": f"api_key: {runtime.context.api_key}, timeout: {runtime.context.timeout}"
379+
}
380+
381+
graph = StateGraph(state_schema=State, context_schema=Context)
382+
graph.add_node("node", node_with_context)
383+
graph.add_edge(START, "node")
384+
graph.add_edge("node", END)
385+
386+
compiled = graph.compile()
387+
388+
with pytest.raises(ValidationError):
389+
compiled.invoke(
390+
{"message": "test"}, context={"api_key": "sk_test", "timeout": "not_an_int"}
391+
)

0 commit comments

Comments
 (0)