From 19d763ac2fa8fb2b91153b1a0f028550b93b9100 Mon Sep 17 00:00:00 2001 From: rohanvasudev1 Date: Fri, 3 Oct 2025 15:23:02 -0700 Subject: [PATCH 1/2] issue #33217 enable middleware and state_schema compatibility Remove artificial constraint preventing middleware and state_schema from being used together in create_agent(). This enables RAG applications that need both conversation summarization middleware and custom state for document storage. - Remove assert state_schema is None constraint in react_agent.py - Add state_schema parameter to middleware_agent.create_agent() - Implement schema validation and automatic merging - Add comprehensive documentation with RAG examples - Maintain full backwards compatibility --- .../langchain/agents/middleware_agent.py | 98 +++++++++++++++++-- .../langchain/agents/react_agent.py | 39 +++++++- 2 files changed, 129 insertions(+), 8 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index ce752aefb65cf..d94f02920b825 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -22,6 +22,7 @@ ModelRequest, OmitFromSchema, PublicAgentState, + StateT, ) from langchain.agents.structured_output import ( MultipleStructuredOutputsError, @@ -166,6 +167,81 @@ def _handle_structured_output_error( return False, "" +def _validate_schema_compatibility( + state_schemas: set[type], + user_state_schema: type | None, + middleware: Sequence[AgentMiddleware], + response_format: ResponseFormat | type | None, +) -> None: + """Validate that user state schema is compatible with middleware schemas. + + Args: + state_schemas: Set of middleware state schemas + user_state_schema: User-provided state schema to validate + middleware: List of middleware instances for error messaging + response_format: Response format to check if structured_response field is required + + Raises: + ValueError: If user schema is incompatible with middleware schemas + """ + if user_state_schema is None: + return + + # Check required fields exist + user_hints = get_type_hints(user_state_schema) + required_fields = {"messages"} + + # remaining_steps is optional in some cases, but messages is always required + if "remaining_steps" not in user_hints: + # Allow missing remaining_steps, but warn about it + pass + + if response_format is not None: + required_fields.add("structured_response") + + missing_fields = required_fields - set(user_hints.keys()) + if missing_fields: + raise ValueError( + f"User state_schema missing required fields: {missing_fields}. " + f"State schemas used with middleware must include: {required_fields}. " + f"Consider extending AgentState: class MyState(AgentState): ..." + ) + + # Check for field conflicts between user schema and middleware schemas + middleware_fields = set() + conflicting_middleware = [] + + for schema in state_schemas: + if schema != AgentState: # Skip base schema + schema_hints = get_type_hints(schema) + schema_fields = set(schema_hints.keys()) - {"messages", "remaining_steps", "structured_response"} + + # Find which middleware defines these fields + for m in middleware: + if hasattr(m, 'state_schema') and m.state_schema == schema: + for field in schema_fields: + if field in user_hints: + conflicting_middleware.append((m.__class__.__name__, field)) + break + + middleware_fields.update(schema_fields) + + user_fields = set(user_hints.keys()) - {"messages", "remaining_steps", "structured_response"} + conflicts = user_fields & middleware_fields + + if conflicts: + conflict_details = [] + for middleware_name, field in conflicting_middleware: + if field in conflicts: + conflict_details.append(f"'{field}' (defined by {middleware_name})") + + raise ValueError( + f"State schema field conflicts detected: {conflicts}. " + f"These fields are already defined by middleware: {conflict_details}. " + f"Please use different field names or consider extending the middleware state schema instead." + ) + + def create_agent( # noqa: PLR0915 *, model: str | BaseChatModel, @@ -174,8 +250,9 @@ def create_agent( # noqa: PLR0915 middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (), response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, context_schema: type[ContextT] | None = None, + state_schema: type[StateT] | None = None, ) -> StateGraph[ - AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] + StateT, ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ]: """Create a middleware agent graph.""" # init chat model @@ -266,15 +343,22 @@ def create_agent( # noqa: PLR0915 state_schemas = {m.state_schema for m in middleware} state_schemas.add(AgentState) - state_schema = _resolve_schema(state_schemas, "StateSchema", None) + # NEW: Include user-provided state_schema in the merging process + if state_schema is not None: + # Validate compatibility before merging + _validate_schema_compatibility(state_schemas, state_schema, middleware, response_format) + state_schemas.add(state_schema) + + # Resolve the final merged schemas + final_state_schema = _resolve_schema(state_schemas, "StateSchema", None) input_schema = _resolve_schema(state_schemas, "InputSchema", "input") output_schema = _resolve_schema(state_schemas, "OutputSchema", "output") # create graph, add nodes graph: StateGraph[ - AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] + StateT, ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ] = StateGraph( - state_schema=state_schema, + state_schema=final_state_schema, input_schema=input_schema, output_schema=output_schema, context_schema=context_schema, @@ -493,7 +577,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ ) before_node = RunnableCallable(sync_before, async_before) graph.add_node( - f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema + f"{m.__class__.__name__}.before_model", before_node, input_schema=final_state_schema ) if ( @@ -514,7 +598,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ ) after_node = RunnableCallable(sync_after, async_after) graph.add_node( - f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema + f"{m.__class__.__name__}.after_model", after_node, input_schema=final_state_schema ) # add start edge @@ -674,7 +758,7 @@ def tools_to_model(state: dict[str, Any]) -> str | None: def _add_middleware_edge( - graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState], + graph: StateGraph[StateT, ContextT, PublicAgentState, PublicAgentState], name: str, default_destination: str, model_destination: str, diff --git a/libs/langchain_v1/langchain/agents/react_agent.py b/libs/langchain_v1/langchain/agents/react_agent.py index cd1cc469c7e0f..8dddfdf958d72 100644 --- a/libs/langchain_v1/langchain/agents/react_agent.py +++ b/libs/langchain_v1/langchain/agents/react_agent.py @@ -979,6 +979,23 @@ def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenA tools: A list of tools or a ToolNode instance. If an empty list is provided, the agent will consist of a single LLM node without tool calling. + middleware: A sequence of middleware instances for customizing agent behavior. + Each middleware can define its own state schema, which will be automatically + merged with any user-provided `state_schema`. Use middleware for features like + conversation summarization, planning, and other cross-cutting concerns. + + !!! note "Middleware + State Schema Compatibility" + As of this version, `middleware` and `state_schema` can be used together: + + ```python + agent = create_agent( + model="openai:gpt-4", + tools=[tools], + middleware=[SummarizationMiddleware()], # For conversation management + state_schema=RAGState, # For custom state fields + ) + ``` + prompt: An optional prompt for the LLM. Can take a few different forms: - str: This is converted to a SystemMessage and added to the beginning @@ -1065,6 +1082,25 @@ def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenA state_schema: An optional state schema that defines graph state. Must have `messages` and `remaining_steps` keys. Defaults to `AgentState` that defines those two keys. + + !!! note "Using with Middleware" + You can now combine `state_schema` with `middleware` for advanced use cases: + + ```python + class RAGState(AgentState): + retrieved_documents: NotRequired[list[dict]] + citations: NotRequired[list[str]] + + agent = create_agent( + model="openai:gpt-4", + tools=[retrieval_tool], + middleware=[SummarizationMiddleware(model=llm)], + state_schema=RAGState, # Now supported! + ) + ``` + + When used with middleware, the schema will be merged with middleware schemas. + Field conflicts between user schema and middleware schemas will raise ValidationError. context_schema: An optional schema for runtime context. checkpointer: An optional checkpoint saver object. This is used for persisting the state of the graph (e.g., as chat memory) for a single thread @@ -1148,7 +1184,7 @@ def check_weather(location: str) -> str: assert not isinstance(response_format, tuple) # noqa: S101 assert pre_model_hook is None # noqa: S101 assert post_model_hook is None # noqa: S101 - assert state_schema is None # noqa: S101 + # REMOVED: assert state_schema is None # Allow state_schema with middleware return create_middleware_agent( # type: ignore[return-value] model=model, tools=tools, @@ -1156,6 +1192,7 @@ def check_weather(location: str) -> str: middleware=middleware, response_format=response_format, context_schema=context_schema, + state_schema=state_schema, # Pass user state_schema to middleware agent ).compile( checkpointer=checkpointer, store=store, From c86144bada21cc320b2e878070ec8ba563e0e208 Mon Sep 17 00:00:00 2001 From: rohanvasudev1 Date: Fri, 3 Oct 2025 16:46:19 -0700 Subject: [PATCH 2/2] Fixing linting issues --- .../langchain_v1/langchain/agents/middleware_agent.py | 11 ++++++++--- libs/langchain_v1/langchain/agents/react_agent.py | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index d94f02920b825..aa7ad6a7fdbca 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -214,7 +214,9 @@ def _validate_schema_compatibility( for schema in state_schemas: if schema != AgentState: # Skip base schema schema_hints = get_type_hints(schema) - schema_fields = set(schema_hints.keys()) - {"messages", "remaining_steps", "structured_response"} + schema_fields = set(schema_hints.keys()) - { + "messages", "remaining_steps", "structured_response" + } # Find which middleware defines these fields for m in middleware: @@ -226,7 +228,9 @@ def _validate_schema_compatibility( middleware_fields.update(schema_fields) - user_fields = set(user_hints.keys()) - {"messages", "remaining_steps", "structured_response"} + user_fields = set(user_hints.keys()) - { + "messages", "remaining_steps", "structured_response" + } conflicts = user_fields & middleware_fields if conflicts: @@ -238,7 +242,8 @@ def _validate_schema_compatibility( raise ValueError( f"State schema field conflicts detected: {conflicts}. " f"These fields are already defined by middleware: {conflict_details}. " - f"Please use different field names or consider extending the middleware state schema instead." + f"Please use different field names or consider extending the middleware " + f"state schema instead." ) diff --git a/libs/langchain_v1/langchain/agents/react_agent.py b/libs/langchain_v1/langchain/agents/react_agent.py index 8dddfdf958d72..1449ef6a268d6 100644 --- a/libs/langchain_v1/langchain/agents/react_agent.py +++ b/libs/langchain_v1/langchain/agents/react_agent.py @@ -1100,7 +1100,8 @@ class RAGState(AgentState): ``` When used with middleware, the schema will be merged with middleware schemas. - Field conflicts between user schema and middleware schemas will raise ValidationError. + Field conflicts between user schema and middleware schemas will raise + ValidationError. context_schema: An optional schema for runtime context. checkpointer: An optional checkpoint saver object. This is used for persisting the state of the graph (e.g., as chat memory) for a single thread