diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index eb07787929124..a991cf540c82f 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -22,17 +22,20 @@ from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode from langgraph.runtime import Runtime # noqa: TC002 from langgraph.types import Command, Send -from langgraph.typing import ContextT # noqa: TC002 +from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, + AsyncModelCallHandler, JumpTo, + ModelCallHandler, ModelRequest, ModelResponse, OmitFromSchema, ResponseT, + StateT, StateT_co, _InputAgentState, _OutputAgentState, @@ -86,13 +89,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp def _chain_model_call_handlers( handlers: Sequence[ Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]], ModelResponse | AIMessage, ] ], ) -> ( Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]], ModelResponse, ] | None @@ -140,8 +143,8 @@ def retry(req, state, runtime, handler): single_handler = handlers[0] def normalized_single( - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT, ContextT], + handler: ModelCallHandler[StateT, ContextT], ) -> ModelResponse: result = single_handler(request, handler) return _normalize_to_model_response(result) @@ -150,25 +153,25 @@ def normalized_single( def compose_two( outer: Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]], ModelResponse | AIMessage, ], inner: Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]], ModelResponse | AIMessage, ], ) -> Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]], ModelResponse, ]: """Compose two handlers where outer wraps inner.""" def composed( - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT, ContextT], + handler: ModelCallHandler[StateT, ContextT], ) -> ModelResponse: # Create a wrapper that calls inner with the base handler and normalizes - def inner_handler(req: ModelRequest) -> ModelResponse: + def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse: inner_result = inner(req, handler) return _normalize_to_model_response(inner_result) @@ -185,8 +188,8 @@ def inner_handler(req: ModelRequest) -> ModelResponse: # Wrap to ensure final return type is exactly ModelResponse def final_normalized( - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT, ContextT], + handler: ModelCallHandler[StateT, ContextT], ) -> ModelResponse: # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes final_result = result(request, handler) @@ -198,13 +201,13 @@ def final_normalized( def _chain_async_model_call_handlers( handlers: Sequence[ Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]], Awaitable[ModelResponse | AIMessage], ] ], ) -> ( Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]], Awaitable[ModelResponse], ] | None @@ -225,8 +228,8 @@ def _chain_async_model_call_handlers( single_handler = handlers[0] async def normalized_single( - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelResponse: result = await single_handler(request, handler) return _normalize_to_model_response(result) @@ -235,25 +238,25 @@ async def normalized_single( def compose_two( outer: Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]], Awaitable[ModelResponse | AIMessage], ], inner: Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]], Awaitable[ModelResponse | AIMessage], ], ) -> Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]], Awaitable[ModelResponse], ]: """Compose two async handlers where outer wraps inner.""" async def composed( - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelResponse: # Create a wrapper that calls inner with the base handler and normalizes - async def inner_handler(req: ModelRequest) -> ModelResponse: + async def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse: inner_result = await inner(req, handler) return _normalize_to_model_response(inner_result) @@ -270,8 +273,8 @@ async def inner_handler(req: ModelRequest) -> ModelResponse: # Wrap to ensure final return type is exactly ModelResponse async def final_normalized( - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelResponse: # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes final_result = await result(request, handler) @@ -546,9 +549,9 @@ def create_agent( # noqa: PLR0915 tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None, *, system_prompt: str | None = None, - middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (), response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, state_schema: type[AgentState[ResponseT]] | None = None, + middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (), context_schema: type[ContextT] | None = None, checkpointer: Checkpointer | None = None, store: BaseStore | None = None, @@ -968,7 +971,9 @@ def _handle_model_output( return {"messages": [output]} - def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]: + def _get_bound_model( + request: ModelRequest[StateT, ContextT], + ) -> tuple[Runnable, ResponseFormat | None]: """Get the model with appropriate tool bindings. Performs auto-detection of strategy if needed based on model capabilities. @@ -1082,7 +1087,7 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | ) return request.model.bind(**request.model_settings), None - def _execute_model_sync(request: ModelRequest) -> ModelResponse: + def _execute_model_sync(request: ModelRequest[StateT, ContextT]) -> ModelResponse: """Execute model and return response. This is the core model execution logic wrapped by `wrap_model_call` handlers. @@ -1106,9 +1111,9 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse: structured_response=structured_response, ) - def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: + def model_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]: """Sync model request handler with sequential middleware processing.""" - request = ModelRequest( + request = ModelRequest[StateT, ContextT]( model=model, tools=default_tools, system_prompt=system_prompt, @@ -1133,7 +1138,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: return state_updates - async def _execute_model_async(request: ModelRequest) -> ModelResponse: + async def _execute_model_async(request: ModelRequest[StateT, ContextT]) -> ModelResponse: """Execute model asynchronously and return response. This is the core async model execution logic wrapped by `wrap_model_call` @@ -1159,9 +1164,9 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse: structured_response=structured_response, ) - async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: + async def amodel_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]: """Async model request handler with sequential middleware processing.""" - request = ModelRequest( + request = ModelRequest[StateT, ContextT]( model=model, tools=default_tools, system_prompt=system_prompt, diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 998fec8571655..816acb03d3bc9 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -45,12 +45,16 @@ __all__ = [ "AgentMiddleware", "AgentState", + "AsyncModelCallHandler", + "AsyncToolCallHandler", "ContextT", + "ModelCallHandler", "ModelRequest", "ModelResponse", "OmitFromSchema", "ResponseT", "StateT_co", + "ToolCallHandler", "ToolCallRequest", "ToolCallWrapper", "after_agent", @@ -68,6 +72,111 @@ ResponseT = TypeVar("ResponseT") +@dataclass +class ModelResponse: + """Response from model execution including messages and optional structured output. + + The result will usually contain a single `AIMessage`, but may include an additional + `ToolMessage` if the model used a tool for structured output. + """ + + result: list[BaseMessage] + """List of messages from model execution.""" + + structured_response: Any = None + """Parsed structured output if `response_format` was specified, `None` otherwise.""" + + +# Type alias for middleware return type - allows returning either full response or just AIMessage +ModelCallResult: TypeAlias = "ModelResponse | AIMessage" +"""`TypeAlias` for model call handler return value. + +Middleware can return either: + +- `ModelResponse`: Full response with messages and optional structured output +- `AIMessage`: Simplified return for simple use cases +""" + +# Type aliases for model call handlers +ModelCallHandler: TypeAlias = "Callable[[ModelRequest[StateT, ContextT]], ModelResponse]" +"""`TypeAlias` for synchronous model call handler callback. + +This is the handler function passed to `wrap_model_call` middleware that executes +the model request and returns a `ModelResponse`. +""" + +AsyncModelCallHandler: TypeAlias = ( + "Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]]" +) +"""`TypeAlias` for asynchronous model call handler callback. + +This is the handler function passed to `awrap_model_call` middleware that executes +the model request and returns an awaitable `ModelResponse`. +""" + +# Type aliases for tool call handlers +ToolCallHandler: TypeAlias = "Callable[[ToolCallRequest], ToolMessage | Command]" +"""`TypeAlias` for synchronous tool call handler callback. + +This is the handler function passed to `wrap_tool_call` middleware that executes +the tool call and returns a `ToolMessage` or `Command`. +""" + +AsyncToolCallHandler: TypeAlias = "Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]" +"""`TypeAlias` for asynchronous tool call handler callback. + +This is the handler function passed to `awrap_tool_call` middleware that executes +the tool call and returns an awaitable `ToolMessage` or `Command`. +""" + + +@dataclass +class OmitFromSchema: + """Annotation used to mark state attributes as omitted from input or output schemas.""" + + input: bool = True + """Whether to omit the attribute from the input schema.""" + + output: bool = True + """Whether to omit the attribute from the output schema.""" + + +OmitFromInput = OmitFromSchema(input=True, output=False) +"""Annotation used to mark state attributes as omitted from input schema.""" + +OmitFromOutput = OmitFromSchema(input=False, output=True) +"""Annotation used to mark state attributes as omitted from output schema.""" + +PrivateStateAttr = OmitFromSchema(input=True, output=True) +"""Annotation used to mark state attributes as purely internal for a given middleware.""" + + +class AgentState(TypedDict, Generic[ResponseT]): + """State schema for the agent.""" + + messages: Required[Annotated[list[AnyMessage], add_messages]] + jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]] + structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] + + +class _InputAgentState(TypedDict): # noqa: PYI049 + """Input state schema for the agent.""" + + messages: Required[Annotated[list[AnyMessage | dict], add_messages]] + + +class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049 + """Output state schema for the agent.""" + + messages: Required[Annotated[list[AnyMessage], add_messages]] + structured_response: NotRequired[ResponseT] + + +StateT = TypeVar("StateT", bound=AgentState, default=AgentState) +StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True) +StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True) + + class _ModelRequestOverrides(TypedDict, total=False): """Possible overrides for `ModelRequest.override()` method.""" @@ -81,8 +190,11 @@ class _ModelRequestOverrides(TypedDict, total=False): @dataclass -class ModelRequest: - """Model request information for the agent.""" +class ModelRequest(Generic[StateT, ContextT]): + """Model request information for the agent. + + Generic over `ContextT` for better type inference of the runtime context. + """ model: BaseChatModel system_prompt: str | None @@ -90,8 +202,8 @@ class ModelRequest: tool_choice: Any | None tools: list[BaseTool | dict] response_format: ResponseFormat | None - state: AgentState - runtime: Runtime[ContextT] # type: ignore[valid-type] + state: StateT + runtime: Runtime[ContextT] model_settings: dict[str, Any] = field(default_factory=dict) def __setattr__(self, name: str, value: Any) -> None: @@ -119,7 +231,9 @@ def __setattr__(self, name: str, value: Any) -> None: ) object.__setattr__(self, name, value) - def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest: + def override( + self, **overrides: Unpack[_ModelRequestOverrides] + ) -> ModelRequest[StateT, ContextT]: """Replace the request with a new request with the given overrides. Returns a new `ModelRequest` instance with the specified attributes replaced. @@ -158,79 +272,6 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest: return replace(self, **overrides) -@dataclass -class ModelResponse: - """Response from model execution including messages and optional structured output. - - The result will usually contain a single `AIMessage`, but may include an additional - `ToolMessage` if the model used a tool for structured output. - """ - - result: list[BaseMessage] - """List of messages from model execution.""" - - structured_response: Any = None - """Parsed structured output if `response_format` was specified, `None` otherwise.""" - - -# Type alias for middleware return type - allows returning either full response or just AIMessage -ModelCallResult: TypeAlias = "ModelResponse | AIMessage" -"""`TypeAlias` for model call handler return value. - -Middleware can return either: - -- `ModelResponse`: Full response with messages and optional structured output -- `AIMessage`: Simplified return for simple use cases -""" - - -@dataclass -class OmitFromSchema: - """Annotation used to mark state attributes as omitted from input or output schemas.""" - - input: bool = True - """Whether to omit the attribute from the input schema.""" - - output: bool = True - """Whether to omit the attribute from the output schema.""" - - -OmitFromInput = OmitFromSchema(input=True, output=False) -"""Annotation used to mark state attributes as omitted from input schema.""" - -OmitFromOutput = OmitFromSchema(input=False, output=True) -"""Annotation used to mark state attributes as omitted from output schema.""" - -PrivateStateAttr = OmitFromSchema(input=True, output=True) -"""Annotation used to mark state attributes as purely internal for a given middleware.""" - - -class AgentState(TypedDict, Generic[ResponseT]): - """State schema for the agent.""" - - messages: Required[Annotated[list[AnyMessage], add_messages]] - jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]] - structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] - - -class _InputAgentState(TypedDict): # noqa: PYI049 - """Input state schema for the agent.""" - - messages: Required[Annotated[list[AnyMessage | dict], add_messages]] - - -class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049 - """Output state schema for the agent.""" - - messages: Required[Annotated[list[AnyMessage], add_messages]] - structured_response: NotRequired[ResponseT] - - -StateT = TypeVar("StateT", bound=AgentState, default=AgentState) -StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True) -StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True) - - class AgentMiddleware(Generic[StateT, ContextT]): """Base middleware class for an agent. @@ -287,8 +328,8 @@ async def aafter_model( def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT, ContextT], + handler: ModelCallHandler[StateT, ContextT], ) -> ModelCallResult: """Intercept and control model execution via handler callback. @@ -382,8 +423,8 @@ def wrap_model_call(self, request, handler): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelCallResult: """Intercept and control async model execution via handler callback. @@ -443,7 +484,7 @@ async def aafter_agent( def wrap_tool_call( self, request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], + handler: ToolCallHandler, ) -> ToolMessage | Command: """Intercept tool execution for retries, monitoring, or modification. @@ -525,7 +566,7 @@ def wrap_tool_call(self, request, handler): async def awrap_tool_call( self, request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + handler: AsyncToolCallHandler, ) -> ToolMessage | Command: """Intercept and control async tool execution via handler callback. @@ -605,7 +646,7 @@ def __call__( class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc] """Callable that returns a prompt string given `ModelRequest` (contains state and runtime).""" - def __call__(self, request: ModelRequest) -> str | Awaitable[str]: + def __call__(self, request: ModelRequest[StateT_contra, ContextT]) -> str | Awaitable[str]: """Generate a system prompt string based on the request.""" ... @@ -619,8 +660,8 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ def __call__( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT_contra, ContextT], + handler: ModelCallHandler[StateT_contra, ContextT], ) -> ModelCallResult: """Intercept model execution via handler callback.""" ... @@ -636,7 +677,7 @@ class _CallableReturningToolResponse(Protocol): def __call__( self, request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], + handler: ToolCallHandler, ) -> ToolMessage | Command: """Intercept tool execution via handler callback.""" ... @@ -1365,8 +1406,8 @@ def decorator( async def async_wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelCallResult: prompt = await func(request) # type: ignore[misc] request = request.override(system_prompt=prompt) @@ -1386,8 +1427,8 @@ async def async_wrapped( def wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT, ContextT], + handler: ModelCallHandler[StateT, ContextT], ) -> ModelCallResult: prompt = cast("str", func(request)) request = request.override(system_prompt=prompt) @@ -1395,8 +1436,8 @@ def wrapped( async def async_wrapped_from_sync( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelCallResult: # Delegate to sync function prompt = cast("str", func(request)) @@ -1537,8 +1578,8 @@ def decorator( async def async_wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[StateT, ContextT], + handler: AsyncModelCallHandler[StateT, ContextT], ) -> ModelCallResult: return await func(request, handler) # type: ignore[misc, arg-type] @@ -1558,8 +1599,8 @@ async def async_wrapped( def wrapped( self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[StateT, ContextT], + handler: ModelCallHandler[StateT, ContextT], ) -> ModelCallResult: return func(request, handler) @@ -1698,7 +1739,7 @@ def decorator( async def async_wrapped( self: AgentMiddleware, # noqa: ARG001 request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + handler: AsyncToolCallHandler, ) -> ToolMessage | Command: return await func(request, handler) # type: ignore[arg-type,misc] @@ -1719,7 +1760,7 @@ async def async_wrapped( def wrapped( self: AgentMiddleware, # noqa: ARG001 request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], + handler: ToolCallHandler, ) -> ToolMessage | Command: return func(request, handler)