diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 17c2825d0e7ca..092fa084668ed 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -22,6 +22,7 @@ ModelRequest, OmitFromSchema, PublicAgentState, + StateT, ) from langchain.agents.structured_output import ( MultipleStructuredOutputsError, @@ -166,6 +167,86 @@ def _handle_structured_output_error( return False, "" +def _validate_schema_compatibility( + state_schemas: set[type], + user_state_schema: type | None, + middleware: Sequence[AgentMiddleware], + response_format: ResponseFormat | type | None, +) -> None: + """Validate that user state schema is compatible with middleware schemas. + + Args: + state_schemas: Set of middleware state schemas + user_state_schema: User-provided state schema to validate + middleware: List of middleware instances for error messaging + response_format: Response format to check if structured_response field is required + + Raises: + ValueError: If user schema is incompatible with middleware schemas + """ + if user_state_schema is None: + return + + # Check required fields exist + user_hints = get_type_hints(user_state_schema) + required_fields = {"messages"} + + # remaining_steps is optional in some cases, but messages is always required + if "remaining_steps" not in user_hints: + # Allow missing remaining_steps, but warn about it + pass + + if response_format is not None: + required_fields.add("structured_response") + + missing_fields = required_fields - set(user_hints.keys()) + if missing_fields: + raise ValueError( + f"User state_schema missing required fields: {missing_fields}. " + f"State schemas used with middleware must include: {required_fields}. " + f"Consider extending AgentState: class MyState(AgentState): ..." + ) + + # Check for field conflicts between user schema and middleware schemas + middleware_fields = set() + conflicting_middleware = [] + + for schema in state_schemas: + if schema != AgentState: # Skip base schema + schema_hints = get_type_hints(schema) + schema_fields = set(schema_hints.keys()) - { + "messages", "remaining_steps", "structured_response" + } + + # Find which middleware defines these fields + for m in middleware: + if hasattr(m, 'state_schema') and m.state_schema == schema: + for field in schema_fields: + if field in user_hints: + conflicting_middleware.append((m.__class__.__name__, field)) + break + + middleware_fields.update(schema_fields) + + user_fields = set(user_hints.keys()) - { + "messages", "remaining_steps", "structured_response" + } + conflicts = user_fields & middleware_fields + + if conflicts: + conflict_details = [] + for middleware_name, field in conflicting_middleware: + if field in conflicts: + conflict_details.append(f"'{field}' (defined by {middleware_name})") + + raise ValueError( + f"State schema field conflicts detected: {conflicts}. " + f"These fields are already defined by middleware: {conflict_details}. " + f"Please use different field names or consider extending the middleware " + f"state schema instead." + ) + + def create_agent( # noqa: PLR0915 *, model: str | BaseChatModel, @@ -174,8 +255,9 @@ def create_agent( # noqa: PLR0915 middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (), response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, context_schema: type[ContextT] | None = None, + state_schema: type[StateT] | None = None, ) -> StateGraph[ - AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] + StateT, ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ]: """Create a middleware agent graph.""" # init chat model @@ -266,15 +348,22 @@ def create_agent( # noqa: PLR0915 state_schemas = {m.state_schema for m in middleware} state_schemas.add(AgentState) - state_schema = _resolve_schema(state_schemas, "StateSchema", None) + # NEW: Include user-provided state_schema in the merging process + if state_schema is not None: + # Validate compatibility before merging + _validate_schema_compatibility(state_schemas, state_schema, middleware, response_format) + state_schemas.add(state_schema) + + # Resolve the final merged schemas + final_state_schema = _resolve_schema(state_schemas, "StateSchema", None) input_schema = _resolve_schema(state_schemas, "InputSchema", "input") output_schema = _resolve_schema(state_schemas, "OutputSchema", "output") # create graph, add nodes graph: StateGraph[ - AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] + StateT, ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ] = StateGraph( - state_schema=state_schema, + state_schema=final_state_schema, input_schema=input_schema, output_schema=output_schema, context_schema=context_schema, @@ -501,7 +590,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ ) before_node = RunnableCallable(sync_before, async_before) graph.add_node( - f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema + f"{m.__class__.__name__}.before_model", before_node, input_schema=final_state_schema ) if ( @@ -522,7 +611,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[ ) after_node = RunnableCallable(sync_after, async_after) graph.add_node( - f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema + f"{m.__class__.__name__}.after_model", after_node, input_schema=final_state_schema ) # add start edge @@ -682,7 +771,7 @@ def tools_to_model(state: dict[str, Any]) -> str | None: def _add_middleware_edge( - graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState], + graph: StateGraph[StateT, ContextT, PublicAgentState, PublicAgentState], name: str, default_destination: str, model_destination: str, diff --git a/libs/langchain_v1/langchain/agents/react_agent.py b/libs/langchain_v1/langchain/agents/react_agent.py index cd1cc469c7e0f..1449ef6a268d6 100644 --- a/libs/langchain_v1/langchain/agents/react_agent.py +++ b/libs/langchain_v1/langchain/agents/react_agent.py @@ -979,6 +979,23 @@ def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenA tools: A list of tools or a ToolNode instance. If an empty list is provided, the agent will consist of a single LLM node without tool calling. + middleware: A sequence of middleware instances for customizing agent behavior. + Each middleware can define its own state schema, which will be automatically + merged with any user-provided `state_schema`. Use middleware for features like + conversation summarization, planning, and other cross-cutting concerns. + + !!! note "Middleware + State Schema Compatibility" + As of this version, `middleware` and `state_schema` can be used together: + + ```python + agent = create_agent( + model="openai:gpt-4", + tools=[tools], + middleware=[SummarizationMiddleware()], # For conversation management + state_schema=RAGState, # For custom state fields + ) + ``` + prompt: An optional prompt for the LLM. Can take a few different forms: - str: This is converted to a SystemMessage and added to the beginning @@ -1065,6 +1082,26 @@ def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenA state_schema: An optional state schema that defines graph state. Must have `messages` and `remaining_steps` keys. Defaults to `AgentState` that defines those two keys. + + !!! note "Using with Middleware" + You can now combine `state_schema` with `middleware` for advanced use cases: + + ```python + class RAGState(AgentState): + retrieved_documents: NotRequired[list[dict]] + citations: NotRequired[list[str]] + + agent = create_agent( + model="openai:gpt-4", + tools=[retrieval_tool], + middleware=[SummarizationMiddleware(model=llm)], + state_schema=RAGState, # Now supported! + ) + ``` + + When used with middleware, the schema will be merged with middleware schemas. + Field conflicts between user schema and middleware schemas will raise + ValidationError. context_schema: An optional schema for runtime context. checkpointer: An optional checkpoint saver object. This is used for persisting the state of the graph (e.g., as chat memory) for a single thread @@ -1148,7 +1185,7 @@ def check_weather(location: str) -> str: assert not isinstance(response_format, tuple) # noqa: S101 assert pre_model_hook is None # noqa: S101 assert post_model_hook is None # noqa: S101 - assert state_schema is None # noqa: S101 + # REMOVED: assert state_schema is None # Allow state_schema with middleware return create_middleware_agent( # type: ignore[return-value] model=model, tools=tools, @@ -1156,6 +1193,7 @@ def check_weather(location: str) -> str: middleware=middleware, response_format=response_format, context_schema=context_schema, + state_schema=state_schema, # Pass user state_schema to middleware agent ).compile( checkpointer=checkpointer, store=store,