Skip to content

Commit c365420

Browse files
fix(langchain): use state schema as input schema to middleware nodes (#33023)
We want state schema as the input schema to middleware nodes because the conditional edges after these nodes need access to the full state. Also, we just generally want all state passed to middleware nodes, so we should be specifying this explicitly. If we don't, the state annotations used by users in their node signatures are used (so they might be missing fields).
1 parent 4d11877 commit c365420

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,17 @@ def create_agent( # noqa: PLR0915
226226
state_schemas = {m.state_schema for m in middleware}
227227
state_schemas.add(AgentState)
228228

229+
state_schema = _resolve_schema(state_schemas, "StateSchema", None)
230+
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
231+
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
232+
229233
# create graph, add nodes
230234
graph: StateGraph[
231235
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
232236
] = StateGraph(
233-
state_schema=_resolve_schema(state_schemas, "StateSchema", None),
234-
input_schema=_resolve_schema(state_schemas, "InputSchema", "input"),
235-
output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"),
237+
state_schema=state_schema,
238+
input_schema=input_schema,
239+
output_schema=output_schema,
236240
context_schema=context_schema,
237241
)
238242

@@ -417,16 +421,12 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
417421
for m in middleware:
418422
if m.__class__.before_model is not AgentMiddleware.before_model:
419423
graph.add_node(
420-
f"{m.__class__.__name__}.before_model",
421-
m.before_model,
422-
input_schema=m.state_schema,
424+
f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
423425
)
424426

425427
if m.__class__.after_model is not AgentMiddleware.after_model:
426428
graph.add_node(
427-
f"{m.__class__.__name__}.after_model",
428-
m.after_model,
429-
input_schema=m.state_schema,
429+
f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
430430
)
431431

432432
# add start edge

0 commit comments

Comments
 (0)