Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 96 additions & 7 deletions libs/langchain_v1/langchain/agents/middleware_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ModelRequest,
OmitFromSchema,
PublicAgentState,
StateT,
)
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
Expand Down Expand Up @@ -166,6 +167,86 @@
return False, ""


def _validate_schema_compatibility(
state_schemas: set[type],
user_state_schema: type | None,
middleware: Sequence[AgentMiddleware],
response_format: ResponseFormat | type | None,
) -> None:
"""Validate that user state schema is compatible with middleware schemas.

Args:
state_schemas: Set of middleware state schemas
user_state_schema: User-provided state schema to validate
middleware: List of middleware instances for error messaging
response_format: Response format to check if structured_response field is required

Raises:
ValueError: If user schema is incompatible with middleware schemas
"""
if user_state_schema is None:
return

# Check required fields exist
user_hints = get_type_hints(user_state_schema)
required_fields = {"messages"}

# remaining_steps is optional in some cases, but messages is always required
if "remaining_steps" not in user_hints:
# Allow missing remaining_steps, but warn about it
pass

if response_format is not None:
required_fields.add("structured_response")

missing_fields = required_fields - set(user_hints.keys())
if missing_fields:
raise ValueError(
f"User state_schema missing required fields: {missing_fields}. "
f"State schemas used with middleware must include: {required_fields}. "
f"Consider extending AgentState: class MyState(AgentState): ..."

Check failure on line 207 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (EM102)

langchain/agents/middleware_agent.py:205:13: EM102 Exception must not use an f-string literal, assign to variable first

Check failure on line 207 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (EM102)

langchain/agents/middleware_agent.py:205:13: EM102 Exception must not use an f-string literal, assign to variable first
)

Check failure on line 208 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (TRY003)

langchain/agents/middleware_agent.py:204:15: TRY003 Avoid specifying long messages outside the exception class

Check failure on line 208 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (TRY003)

langchain/agents/middleware_agent.py:204:15: TRY003 Avoid specifying long messages outside the exception class

# Check for field conflicts between user schema and middleware schemas
middleware_fields = set()
conflicting_middleware = []

for schema in state_schemas:
if schema != AgentState: # Skip base schema
schema_hints = get_type_hints(schema)
schema_fields = set(schema_hints.keys()) - {
"messages", "remaining_steps", "structured_response"
}

# Find which middleware defines these fields
for m in middleware:
if hasattr(m, 'state_schema') and m.state_schema == schema:

Check failure on line 223 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (Q000)

langchain/agents/middleware_agent.py:223:31: Q000 Single quotes found but double quotes preferred

Check failure on line 223 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (Q000)

langchain/agents/middleware_agent.py:223:31: Q000 Single quotes found but double quotes preferred
for field in schema_fields:
if field in user_hints:
conflicting_middleware.append((m.__class__.__name__, field))

Check failure on line 226 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (PERF401)

langchain/agents/middleware_agent.py:226:29: PERF401 Use `list.extend` to create a transformed list

Check failure on line 226 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (PERF401)

langchain/agents/middleware_agent.py:226:29: PERF401 Use `list.extend` to create a transformed list
break

middleware_fields.update(schema_fields)

user_fields = set(user_hints.keys()) - {
"messages", "remaining_steps", "structured_response"
}
conflicts = user_fields & middleware_fields

if conflicts:
conflict_details = []
for middleware_name, field in conflicting_middleware:
if field in conflicts:
conflict_details.append(f"'{field}' (defined by {middleware_name})")

raise ValueError(
f"State schema field conflicts detected: {conflicts}. "
f"These fields are already defined by middleware: {conflict_details}. "
f"Please use different field names or consider extending the middleware "
f"state schema instead."

Check failure on line 246 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (EM102)

langchain/agents/middleware_agent.py:243:13: EM102 Exception must not use an f-string literal, assign to variable first

Check failure on line 246 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (EM102)

langchain/agents/middleware_agent.py:243:13: EM102 Exception must not use an f-string literal, assign to variable first
)

Check failure on line 247 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (TRY003)

langchain/agents/middleware_agent.py:242:15: TRY003 Avoid specifying long messages outside the exception class

Check failure on line 247 in libs/langchain_v1/langchain/agents/middleware_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (TRY003)

langchain/agents/middleware_agent.py:242:15: TRY003 Avoid specifying long messages outside the exception class


def create_agent( # noqa: PLR0915
*,
model: str | BaseChatModel,
Expand All @@ -174,8 +255,9 @@
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
context_schema: type[ContextT] | None = None,
state_schema: type[StateT] | None = None,
) -> StateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
StateT, ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
"""Create a middleware agent graph."""
# init chat model
Expand Down Expand Up @@ -266,15 +348,22 @@
state_schemas = {m.state_schema for m in middleware}
state_schemas.add(AgentState)

state_schema = _resolve_schema(state_schemas, "StateSchema", None)
# NEW: Include user-provided state_schema in the merging process
if state_schema is not None:
# Validate compatibility before merging
_validate_schema_compatibility(state_schemas, state_schema, middleware, response_format)
state_schemas.add(state_schema)

# Resolve the final merged schemas
final_state_schema = _resolve_schema(state_schemas, "StateSchema", None)
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")

# create graph, add nodes
graph: StateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
StateT, ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
] = StateGraph(
state_schema=state_schema,
state_schema=final_state_schema,
input_schema=input_schema,
output_schema=output_schema,
context_schema=context_schema,
Expand Down Expand Up @@ -501,7 +590,7 @@
)
before_node = RunnableCallable(sync_before, async_before)
graph.add_node(
f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema
f"{m.__class__.__name__}.before_model", before_node, input_schema=final_state_schema
)

if (
Expand All @@ -522,7 +611,7 @@
)
after_node = RunnableCallable(sync_after, async_after)
graph.add_node(
f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema
f"{m.__class__.__name__}.after_model", after_node, input_schema=final_state_schema
)

# add start edge
Expand Down Expand Up @@ -682,7 +771,7 @@


def _add_middleware_edge(
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
graph: StateGraph[StateT, ContextT, PublicAgentState, PublicAgentState],
name: str,
default_destination: str,
model_destination: str,
Expand Down
40 changes: 39 additions & 1 deletion libs/langchain_v1/langchain/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@
)


def create_agent( # noqa: D417

Check failure on line 907 in libs/langchain_v1/langchain/agents/react_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.10) / Python 3.10

Ruff (RUF100)

langchain/agents/react_agent.py:907:20: RUF100 Unused `noqa` directive (unused: `D417`)

Check failure on line 907 in libs/langchain_v1/langchain/agents/react_agent.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain_v1, 3.13) / Python 3.13

Ruff (RUF100)

langchain/agents/react_agent.py:907:20: RUF100 Unused `noqa` directive (unused: `D417`)
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode,
*,
Expand Down Expand Up @@ -979,6 +979,23 @@
tools: A list of tools or a ToolNode instance.
If an empty list is provided, the agent will consist of a single LLM node
without tool calling.
middleware: A sequence of middleware instances for customizing agent behavior.
Each middleware can define its own state schema, which will be automatically
merged with any user-provided `state_schema`. Use middleware for features like
conversation summarization, planning, and other cross-cutting concerns.

!!! note "Middleware + State Schema Compatibility"
As of this version, `middleware` and `state_schema` can be used together:

```python
agent = create_agent(
model="openai:gpt-4",
tools=[tools],
middleware=[SummarizationMiddleware()], # For conversation management
state_schema=RAGState, # For custom state fields
)
```

prompt: An optional prompt for the LLM. Can take a few different forms:

- str: This is converted to a SystemMessage and added to the beginning
Expand Down Expand Up @@ -1065,6 +1082,26 @@
state_schema: An optional state schema that defines graph state.
Must have `messages` and `remaining_steps` keys.
Defaults to `AgentState` that defines those two keys.

!!! note "Using with Middleware"
You can now combine `state_schema` with `middleware` for advanced use cases:

```python
class RAGState(AgentState):
retrieved_documents: NotRequired[list[dict]]
citations: NotRequired[list[str]]

agent = create_agent(
model="openai:gpt-4",
tools=[retrieval_tool],
middleware=[SummarizationMiddleware(model=llm)],
state_schema=RAGState, # Now supported!
)
```

When used with middleware, the schema will be merged with middleware schemas.
Field conflicts between user schema and middleware schemas will raise
`ValueError`.
context_schema: An optional schema for runtime context.
checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread
Expand Down Expand Up @@ -1148,14 +1185,15 @@
assert not isinstance(response_format, tuple) # noqa: S101
assert pre_model_hook is None # noqa: S101
assert post_model_hook is None # noqa: S101
assert state_schema is None # noqa: S101
# REMOVED: assert state_schema is None # Allow state_schema with middleware
return create_middleware_agent( # type: ignore[return-value]
model=model,
tools=tools,
system_prompt=prompt,
middleware=middleware,
response_format=response_format,
context_schema=context_schema,
state_schema=state_schema, # Pass user state_schema to middleware agent
).compile(
checkpointer=checkpointer,
store=store,
Expand Down
Loading