Skip to content

Commit 4d11877

Browse files
feat(langchain): dynamic system prompt middleware (#33006)
# Changes

## Adds support for `DynamicSystemPromptMiddleware`

```py
from langchain.agents.middleware import DynamicSystemPromptMiddleware
from langgraph.runtime import Runtime
from typing_extensions import TypedDict


class Context(TypedDict):
    user_name: str


def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
    user_name = runtime.context.get("user_name", "n/a")
    return f"You are a helpful assistant. Always address the user by their name: {user_name}"


middleware = DynamicSystemPromptMiddleware(system_prompt)
```

## Adds support for `runtime` in middleware hooks

```py
class AgentMiddleware(Generic[StateT, ContextT]):
    def modify_model_request(
        self,
        request: ModelRequest,
        state: StateT,
        runtime: Runtime[ContextT],  # Optional runtime parameter
    ) -> ModelRequest:
        # upgrade model if runtime.context.subscription is `top-tier` or whatever
        ...
```

## Adds support for omitting state attributes from input / output schemas

```py
from typing import Annotated, NotRequired
from langchain.agents.middleware.types import PrivateStateAttr, OmitFromInput, OmitFromOutput


class CustomState(AgentState):
    # Private field - not in input or output schemas
    internal_counter: NotRequired[Annotated[int, PrivateStateAttr]]

    # Input-only field - not in output schema
    user_input: NotRequired[Annotated[str, OmitFromOutput]]

    # Output-only field - not in input schema
    computed_result: NotRequired[Annotated[str, OmitFromInput]]
```

## Additionally

* Removes filtering of state before passing into middleware hooks

Typing is not foolproof here; still need to figure out some of the generics
work with state and context schema extensions for middleware.

TODO:
* More docs for middleware — should hold off on this until other priorities
  like MCP and deepagents are met.

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent f158cea commit 4d11877

File tree

8 files changed

+370
-52
lines changed

8 files changed

+370
-52
lines changed

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Middleware plugins for agents."""
22

3+
from .dynamic_system_prompt import DynamicSystemPromptMiddleware
34
from .human_in_the_loop import HumanInTheLoopMiddleware
45
from .prompt_caching import AnthropicPromptCachingMiddleware
56
from .summarization import SummarizationMiddleware
@@ -8,7 +9,9 @@
89
__all__ = [
910
"AgentMiddleware",
1011
"AgentState",
12+
# should move to langchain-anthropic if we decide to keep it
1113
"AnthropicPromptCachingMiddleware",
14+
"DynamicSystemPromptMiddleware",
1215
"HumanInTheLoopMiddleware",
1316
"ModelRequest",
1417
"SummarizationMiddleware",
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Dynamic System Prompt Middleware.
2+
3+
Allows setting the system prompt dynamically right before each model invocation.
4+
Useful when the prompt depends on the current agent state or per-invocation context.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from inspect import signature
10+
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
11+
12+
from langgraph.typing import ContextT
13+
14+
from langchain.agents.middleware.types import (
15+
AgentMiddleware,
16+
AgentState,
17+
ModelRequest,
18+
)
19+
20+
if TYPE_CHECKING:
21+
from langgraph.runtime import Runtime
22+
23+
24+
class DynamicSystemPromptWithoutRuntime(Protocol):
    """Dynamic system prompt callable without runtime in its call signature.

    Structural type for prompt functions that only need the agent state.
    """

    def __call__(self, state: AgentState) -> str:
        """Return the system prompt for the next model call."""
        ...
30+
31+
32+
class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
    """Dynamic system prompt callable with runtime in its call signature.

    Structural type for prompt functions that need per-invocation context
    (e.g. values from ``runtime.context``) in addition to the agent state.
    """

    def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
        """Return the system prompt for the next model call."""
        ...
38+
39+
40+
# Either call signature is accepted; the middleware inspects the callable at
# construction time to decide whether to pass the runtime (see
# DynamicSystemPromptMiddleware.__init__).
DynamicSystemPrompt: TypeAlias = (
    DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
)
43+
44+
45+
class DynamicSystemPromptMiddleware(AgentMiddleware[AgentState, ContextT]):
    """Dynamic System Prompt Middleware.

    Allows setting the system prompt dynamically right before each model invocation.
    Useful when the prompt depends on the current agent state or per-invocation context.

    The base class is parameterized as ``AgentMiddleware[AgentState, ContextT]`` so
    that type checkers carry the context type through to ``modify_model_request``;
    subclassing the bare ``AgentMiddleware`` would erase ``ContextT``.

    Example:
        ```python
        from langchain.agents.middleware import DynamicSystemPromptMiddleware


        class Context(TypedDict):
            user_name: str


        def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
            user_name = runtime.context.get("user_name", "n/a")
            return (
                f"You are a helpful assistant. Always address the user by their name: {user_name}"
            )


        middleware = DynamicSystemPromptMiddleware(system_prompt)
        ```
    """

    # Cached result of the one-time signature inspection (see __init__).
    _accepts_runtime: bool

    def __init__(
        self,
        dynamic_system_prompt: DynamicSystemPrompt[ContextT],
    ) -> None:
        """Initialize the dynamic system prompt middleware.

        Args:
            dynamic_system_prompt: Function that receives the current agent state
                and optionally runtime with context, and returns the system prompt for
                the next model call. Returns a string.
        """
        super().__init__()
        self.dynamic_system_prompt = dynamic_system_prompt
        # Inspect the callable once here so each model request does not pay the
        # cost of signature introspection.
        self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,
        runtime: Runtime[ContextT],
    ) -> ModelRequest:
        """Modify the model request to include the dynamic system prompt.

        Args:
            request: The model request about to be sent; mutated in place.
            state: Current agent state passed to the prompt function.
            runtime: LangGraph runtime, forwarded only if the prompt function
                declared a ``runtime`` parameter.

        Returns:
            The same request, with ``system_prompt`` replaced by the dynamic value.
        """
        # The casts narrow the union to the signature we verified in __init__.
        if self._accepts_runtime:
            system_prompt = cast(
                "DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
            )(state, runtime)
        else:
            system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
                state
            )

        request.system_prompt = system_prompt
        return request

libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __init__(
143143
self.tool_configs = resolved_tool_configs
144144
self.description_prefix = description_prefix
145145

146-
def after_model(self, state: AgentState) -> dict[str, Any] | None:
146+
def after_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
147147
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
148148
messages = state["messages"]
149149
if not messages:

libs/langchain_v1/langchain/agents/middleware/prompt_caching.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Literal
44

5-
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
5+
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
66

77

88
class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -32,7 +32,10 @@ def __init__(
3232
self.ttl = ttl
3333
self.min_messages_to_cache = min_messages_to_cache
3434

35-
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest: # noqa: ARG002
35+
def modify_model_request( # type: ignore[override]
36+
self,
37+
request: ModelRequest,
38+
) -> ModelRequest:
3639
"""Modify the model request to add cache control blocks."""
3740
try:
3841
from langchain_anthropic import ChatAnthropic

libs/langchain_v1/langchain/agents/middleware/summarization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self.summary_prompt = summary_prompt
9999
self.summary_prefix = summary_prefix
100100

101-
def before_model(self, state: AgentState) -> dict[str, Any] | None:
101+
def before_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
102102
"""Process messages before model invocation, potentially triggering summarization."""
103103
messages = state["messages"]
104104
self._ensure_message_ids(messages)

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,27 @@
88
# needed as top level import for pydantic schema generation on AgentState
99
from langchain_core.messages import AnyMessage # noqa: TC002
1010
from langgraph.channels.ephemeral_value import EphemeralValue
11-
from langgraph.graph.message import Messages, add_messages
11+
from langgraph.graph.message import add_messages
12+
from langgraph.runtime import Runtime
13+
from langgraph.typing import ContextT
1214
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
1315

1416
if TYPE_CHECKING:
1517
from langchain_core.language_models.chat_models import BaseChatModel
1618
from langchain_core.tools import BaseTool
19+
from langgraph.runtime import Runtime
1720

1821
from langchain.agents.structured_output import ResponseFormat
1922

23+
__all__ = [
24+
"AgentMiddleware",
25+
"AgentState",
26+
"ContextT",
27+
"ModelRequest",
28+
"OmitFromSchema",
29+
"PublicAgentState",
30+
]
31+
2032
JumpTo = Literal["tools", "model", "__end__"]
2133
"""Destination to jump to when a middleware node returns."""
2234

@@ -36,26 +48,49 @@ class ModelRequest:
3648
model_settings: dict[str, Any] = field(default_factory=dict)
3749

3850

51+
@dataclass(frozen=True)
class OmitFromSchema:
    """Annotation used to mark state attributes as omitted from input or output schemas.

    Frozen because the module-level sentinels below (``OmitFromInput`` etc.) are
    shared singletons embedded in ``Annotated`` metadata across many state
    classes: mutating one would silently change the schema of every user, and
    freezing also makes instances hashable for use as metadata keys.
    """

    input: bool = True
    """Whether to omit the attribute from the input schema."""

    output: bool = True
    """Whether to omit the attribute from the output schema."""


OmitFromInput = OmitFromSchema(input=True, output=False)
"""Annotation used to mark state attributes as omitted from input schema."""

OmitFromOutput = OmitFromSchema(input=False, output=True)
"""Annotation used to mark state attributes as omitted from output schema."""

PrivateStateAttr = OmitFromSchema(input=True, output=True)
"""Annotation used to mark state attributes as purely internal for a given middleware."""
70+
71+
3972
class AgentState(TypedDict, Generic[ResponseT]):
    """State schema for the agent."""

    # Conversation history; the `add_messages` reducer merges updates by message ID.
    messages: Required[Annotated[list[AnyMessage], add_messages]]
    # Control-flow signal consumed between middleware nodes. EphemeralValue means it
    # does not persist across steps; PrivateStateAttr keeps it out of the public
    # input/output schemas.
    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
    # Structured output of the agent, when a response format is configured.
    response: NotRequired[ResponseT]
4678

4779

4880
class PublicAgentState(TypedDict, Generic[ResponseT]):
    """Public state schema for the agent.

    Just used for typing purposes.
    """

    # Conversation history; same reducer annotation as AgentState.messages.
    messages: Required[Annotated[list[AnyMessage], add_messages]]
    # Structured output of the agent, when a response format is configured.
    response: NotRequired[ResponseT]
5388

5489

55-
# State type parameter for middleware; `default=` (PEP 696, via typing_extensions)
# lets `AgentMiddleware` be used unparameterized as `AgentMiddleware[AgentState, ...]`.
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
5691

5792

58-
class AgentMiddleware(Generic[StateT]):
93+
class AgentMiddleware(Generic[StateT, ContextT]):
5994
"""Base middleware class for an agent.
6095
6196
Subclass this and implement any of the defined methods to customize agent behavior
@@ -68,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
68103
tools: list[BaseTool]
69104
"""Additional tools registered by the middleware."""
70105

71-
def before_model(self, state: StateT) -> dict[str, Any] | None:
106+
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
72107
"""Logic to run before the model is called."""
73108

74-
def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest: # noqa: ARG002
109+
def modify_model_request(
110+
self,
111+
request: ModelRequest,
112+
state: StateT, # noqa: ARG002
113+
runtime: Runtime[ContextT], # noqa: ARG002
114+
) -> ModelRequest:
75115
"""Logic to modify request kwargs before the model is called."""
76116
return request
77117

78-
def after_model(self, state: StateT) -> dict[str, Any] | None:
118+
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
79119
"""Logic to run after the model is called."""

0 commit comments

Comments
 (0)