From 4689cdae2cc3fa14ba02116f43ca67b6392c3495 Mon Sep 17 00:00:00 2001 From: yanglikun Date: Fri, 21 Nov 2025 11:30:19 +0800 Subject: [PATCH] fix: correct imports and create_agent construction --- src/oss/langchain/middleware/custom.mdx | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/oss/langchain/middleware/custom.mdx b/src/oss/langchain/middleware/custom.mdx index 318cc5d784..aa8982d9c9 100644 --- a/src/oss/langchain/middleware/custom.mdx +++ b/src/oss/langchain/middleware/custom.mdx @@ -373,15 +373,19 @@ Middleware can extend the agent's state with custom properties. ```python +from langchain.agents import create_agent +from langchain_core.messages import HumanMessage from langchain.agents.middleware import AgentState, before_model, after_model from typing_extensions import NotRequired from typing import Any from langgraph.runtime import Runtime + class CustomState(AgentState): model_call_count: NotRequired[int] user_id: NotRequired[str] + @before_model(state_schema=CustomState, can_jump_to=["end"]) def check_call_limit(state: CustomState, runtime: Runtime) -> dict[str, Any] | None: count = state.get("model_call_count", 0) @@ -389,14 +393,16 @@ def check_call_limit(state: CustomState, runtime: Runtime) -> dict[str, Any] | N return {"jump_to": "end"} return None + @after_model(state_schema=CustomState) def increment_counter(state: CustomState, runtime: Runtime) -> dict[str, Any] | None: return {"model_call_count": state.get("model_call_count", 0) + 1} + agent = create_agent( model="gpt-4o", middleware=[check_call_limit, increment_counter], - tools=[...], + tools=[], ) # Invoke with custom state @@ -412,14 +418,18 @@ result = agent.invoke({ ```python +from langchain.agents import create_agent +from langchain_core.messages import HumanMessage from langchain.agents.middleware import AgentState, AgentMiddleware from typing_extensions import NotRequired from typing import Any + class CustomState(AgentState): model_call_count: NotRequired[int] user_id: NotRequired[str] + class CallCounterMiddleware(AgentMiddleware[CustomState]): state_schema = CustomState @@ -432,10 +442,11 @@ class CallCounterMiddleware(AgentMiddleware[CustomState]): def after_model(self, state: CustomState, runtime) -> dict[str, Any] | None: return {"model_call_count": state.get("model_call_count", 0) + 1} + agent = create_agent( model="gpt-4o", middleware=[CallCounterMiddleware()], - tools=[...], + tools=[], ) # Invoke with custom state