diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py
index eb07787929124..e221f9ac69eba 100644
--- a/libs/langchain_v1/langchain/agents/factory.py
+++ b/libs/langchain_v1/langchain/agents/factory.py
@@ -545,7 +545,7 @@ def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
     *,
-    system_prompt: str | None = None,
+    system_prompt: str | SystemMessage | None = None,
     middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    state_schema: type[AgentState[ResponseT]] | None = None,
@@ -591,9 +591,9 @@ def create_agent(  # noqa: PLR0915
             docs for more information.
         system_prompt: An optional system prompt for the LLM.
-            Prompts are converted to a
-            [`SystemMessage`][langchain.messages.SystemMessage] and added to the
-            beginning of the message list.
+            Can be a `str` (which will be converted to a `SystemMessage`) or a
+            `SystemMessage` instance directly. The system message is added to the
+            beginning of the message list when calling the model.
         middleware: A sequence of middleware instances to apply to the agent.
             Middleware can intercept and modify agent behavior at various stages.
@@ -688,6 +688,14 @@ def check_weather(location: str) -> str:
     if isinstance(model, str):
         model = init_chat_model(model)

+    # Convert system_prompt to SystemMessage if needed
+    system_message: SystemMessage | None = None
+    if system_prompt is not None:
+        if isinstance(system_prompt, SystemMessage):
+            system_message = system_prompt
+        else:
+            system_message = SystemMessage(content=system_prompt)
+
     # Handle tools being None or empty
     if tools is None:
         tools = []
@@ -1091,8 +1099,8 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:
        # Get the bound model (with auto-detection if needed)
        model_, effective_response_format = _get_bound_model(request)
        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
+        if request.system_message:
+            messages = [request.system_message, *messages]

        output = model_.invoke(messages)
@@ -1111,7 +1119,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        request = ModelRequest(
            model=model,
            tools=default_tools,
-            system_prompt=system_prompt,
+            system_message=system_message,
            response_format=initial_response_format,
            messages=state["messages"],
            tool_choice=None,
@@ -1144,8 +1152,8 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:
        # Get the bound model (with auto-detection if needed)
        model_, effective_response_format = _get_bound_model(request)
        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
+        if request.system_message:
+            messages = [request.system_message, *messages]

        output = await model_.ainvoke(messages)
@@ -1164,7 +1172,7 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
        request = ModelRequest(
            model=model,
            tools=default_tools,
-            system_prompt=system_prompt,
+            system_message=system_message,
            response_format=initial_response_format,
            messages=state["messages"],
            tool_choice=None,
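For review context, a minimal sketch of the call pattern this `factory.py` change enables. It reuses the `GenericFakeChatModel` setup from this PR's tests; the prompt text and `additional_kwargs` payload are illustrative.

```python
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, SystemMessage

from langchain.agents import create_agent

model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))

# Existing behavior: a plain string is converted to a SystemMessage internally.
agent = create_agent(model=model, system_prompt="You are a helpful assistant")

# New behavior: a SystemMessage instance is passed through as-is, so message
# metadata (illustrative here) survives the trip to the model.
agent = create_agent(
    model=model,
    system_prompt=SystemMessage(
        content="You are a helpful assistant",
        additional_kwargs={"role": "system_admin"},
    ),
)
```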
diff --git a/libs/langchain_v1/langchain/agents/middleware/context_editing.py b/libs/langchain_v1/langchain/agents/middleware/context_editing.py
index 42488276abda3..538b6ee56766e 100644
--- a/libs/langchain_v1/langchain/agents/middleware/context_editing.py
+++ b/libs/langchain_v1/langchain/agents/middleware/context_editing.py
@@ -18,7 +18,6 @@
     AIMessage,
     AnyMessage,
     BaseMessage,
-    SystemMessage,
     ToolMessage,
 )
 from langchain_core.messages.utils import count_tokens_approximately
@@ -230,9 +229,7 @@ def wrap_model_call(
            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)
        else:
-            system_msg = (
-                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
-            )
+            system_msg = [request.system_message] if request.system_message else []

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
@@ -259,9 +256,7 @@ async def awrap_model_call(
            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)
        else:
-            system_msg = (
-                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
-            )
+            system_msg = [request.system_message] if request.system_message else []

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
diff --git a/libs/langchain_v1/langchain/agents/middleware/todo.py b/libs/langchain_v1/langchain/agents/middleware/todo.py
index ac5bfbee0881e..ce70de42cab48 100644
--- a/libs/langchain_v1/langchain/agents/middleware/todo.py
+++ b/libs/langchain_v1/langchain/agents/middleware/todo.py
@@ -8,7 +8,7 @@
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable

-from langchain_core.messages import ToolMessage
+from langchain_core.messages import SystemMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
 from typing_extensions import NotRequired, TypedDict
@@ -193,23 +193,25 @@ def wrap_model_call(
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
-        """Update the system prompt to include the todo system prompt."""
-        new_system_prompt = (
-            request.system_prompt + "\n\n" + self.system_prompt
-            if request.system_prompt
+        """Update the system message to include the todo system prompt."""
+        new_system_content = (
+            request.system_message.content + "\n\n" + self.system_prompt
+            if request.system_message
            else self.system_prompt
        )
-        return handler(request.override(system_prompt=new_system_prompt))
+        new_system_message = SystemMessage(content=new_system_content)
+        return handler(request.override(system_message=new_system_message))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
-        """Update the system message to include the todo system prompt (async version)."""
-        new_system_prompt = (
-            request.system_prompt + "\n\n" + self.system_prompt
-            if request.system_prompt
+        """Update the system message to include the todo system prompt (async version)."""
+        new_system_content = (
+            request.system_message.content + "\n\n" + self.system_prompt
+            if request.system_message
            else self.system_prompt
        )
-        return await handler(request.override(system_prompt=new_system_prompt))
+        new_system_message = SystemMessage(content=new_system_content)
+        return await handler(request.override(system_message=new_system_message))
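The `todo.py` change above is the canonical example of the new middleware pattern: read `request.system_message`, build a new `SystemMessage`, and hand off via `override(system_message=...)`. Below is a hypothetical custom middleware following that same pattern; the class name is invented, and the import paths are assumed from this diff's module layout.

```python
from collections.abc import Callable

from langchain_core.messages import SystemMessage

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)


class ToneMiddleware(AgentMiddleware):
    """Hypothetical middleware that appends a tone instruction."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        suffix = "Always answer politely."
        # Extend the existing system message if one is set, otherwise start fresh.
        new_content = (
            request.system_message.content + "\n\n" + suffix
            if request.system_message
            else suffix
        )
        return handler(
            request.override(system_message=SystemMessage(content=new_content))
        )
```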
diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py
index 998fec8571655..5e8574620e2e4 100644
--- a/libs/langchain_v1/langchain/agents/middleware/types.py
+++ b/libs/langchain_v1/langchain/agents/middleware/types.py
@@ -20,12 +20,14 @@
 from collections.abc import Awaitable  # Needed as top level import for Pydantic schema generation on AgentState
+import warnings
 from typing import TypeAlias

-from langchain_core.messages import (  # noqa: TC002
+from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
+    SystemMessage,
     ToolMessage,
 )
 from langgraph.channels.ephemeral_value import EphemeralValue
@@ -72,7 +74,7 @@ class _ModelRequestOverrides(TypedDict, total=False):
     """Possible overrides for `ModelRequest.override()` method."""

     model: BaseChatModel
-    system_prompt: str | None
+    system_message: SystemMessage | None
     messages: list[AnyMessage]
     tool_choice: Any | None
     tools: list[BaseTool | dict]
@@ -80,13 +82,13 @@ class _ModelRequestOverrides(TypedDict, total=False):

-@dataclass
+@dataclass(init=False)
 class ModelRequest:
     """Model request information for the agent."""

     model: BaseChatModel
-    system_prompt: str | None
-    messages: list[AnyMessage]  # excluding system prompt
+    messages: list[AnyMessage]  # excluding system message
+    system_message: SystemMessage | None
     tool_choice: Any | None
     tools: list[BaseTool | dict]
     response_format: ResponseFormat | None
@@ -94,6 +96,65 @@ class ModelRequest:
     runtime: Runtime[ContextT]  # type: ignore[valid-type]
     model_settings: dict[str, Any] = field(default_factory=dict)

+    def __init__(
+        self,
+        *,
+        model: BaseChatModel,
+        messages: list[AnyMessage],
+        system_message: SystemMessage | None = None,
+        system_prompt: str | None = None,
+        tool_choice: Any | None = None,
+        tools: list[BaseTool | dict] | None = None,
+        response_format: ResponseFormat | None = None,
+        state: AgentState | None = None,
+        runtime: Runtime[ContextT] | None = None,
+        model_settings: dict[str, Any] | None = None,
+    ) -> None:
+        """Initialize ModelRequest with backward compatibility for system_prompt.
+
+        Args:
+            model: The chat model to use.
+            messages: List of messages (excluding system prompt).
+            tool_choice: Tool choice configuration.
+            tools: List of available tools.
+            response_format: Response format specification.
+            state: Agent state.
+            runtime: Runtime context.
+            model_settings: Additional model settings.
+            system_message: System message instance (preferred).
+            system_prompt: System prompt string (deprecated, converted to SystemMessage).
+        """
+        # Handle system_prompt/system_message conversion and validation
+        if system_prompt is not None and system_message is not None:
+            msg = "Cannot specify both system_prompt and system_message"
+            raise ValueError(msg)
+
+        if system_prompt is not None:
+            system_message = SystemMessage(content=system_prompt)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=DeprecationWarning)
+            self.model = model
+            self.messages = messages
+            self.system_message = system_message
+            self.tool_choice = tool_choice
+            self.tools = tools if tools is not None else []
+            self.response_format = response_format
+            self.state = state if state is not None else {"messages": []}
+            self.runtime = runtime  # type: ignore[assignment]
+            self.model_settings = model_settings if model_settings is not None else {}
+
+    @property
+    def system_prompt(self) -> str | None:
+        """Get system prompt text from system_message.
+
+        Returns:
+            The content of the system message if present, otherwise `None`.
+        """
+        if self.system_message is None:
+            return None
+        return self.system_message.text
+
     def __setattr__(self, name: str, value: Any) -> None:
         """Set an attribute with a deprecation warning.
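Constructor-level compatibility, per the `__init__` above: `system_prompt` is still accepted but converted, passing both spellings raises `ValueError`, and the old field survives as a read-only property. A quick sketch using the fake model and import path from this PR's tests:

```python
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain.agents.middleware.types import ModelRequest

model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))

# Preferred spelling: pass a SystemMessage directly.
request = ModelRequest(
    model=model,
    messages=[HumanMessage("Hi")],
    system_message=SystemMessage(content="You are terse."),
)

# Deprecated spelling: still accepted, converted to a SystemMessage internally.
legacy = ModelRequest(
    model=model,
    messages=[HumanMessage("Hi")],
    system_prompt="You are terse.",
)
assert isinstance(legacy.system_message, SystemMessage)
assert legacy.system_prompt == "You are terse."  # property reads system_message.text

# ModelRequest(..., system_message=..., system_prompt=...) raises ValueError.
```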
@@ -104,20 +165,29 @@ def __setattr__(self, name: str, value: Any) -> None:
             name: Attribute name.
             value: Attribute value.
         """
-        import warnings
-
-        # Allow setting attributes during __init__ (when object is being constructed)
-        if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
-            object.__setattr__(self, name, value)
-        else:
+        # Special handling for system_prompt - convert to system_message
+        if name == "system_prompt":
             warnings.warn(
-                f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
-                f"Use request.override({name}=...) instead to create a new request "
-                f"with the modified attribute.",
+                "Direct attribute assignment to ModelRequest.system_prompt is deprecated. "
+                "Use request.override(system_message=SystemMessage(...)) instead to create "
+                "a new request with the modified system message.",
                 DeprecationWarning,
                 stacklevel=2,
             )
-            object.__setattr__(self, name, value)
+            if value is None:
+                object.__setattr__(self, "system_message", None)
+            else:
+                object.__setattr__(self, "system_message", SystemMessage(content=value))
+            return
+
+        warnings.warn(
+            f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
+            f"Use request.override({name}=...) instead to create a new request "
+            f"with the modified attribute.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        object.__setattr__(self, name, value)

     def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
         """Replace the request with a new request with the given overrides.
@@ -132,7 +202,8 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
         Supported keys:

         - `model`: `BaseChatModel` instance
-        - `system_prompt`: Optional system prompt string
+        - `system_prompt`: deprecated, use `system_message` instead
+        - `system_message`: `SystemMessage` instance
         - `messages`: `list` of messages
         - `tool_choice`: Tool choice configuration
         - `tools`: `list` of available tools
@@ -149,12 +220,36 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
                 new_request = request.override(model=different_model)
             ```

-        !!! example "Override multiple attributes"
+        !!! example "Override system message (preferred)"

             ```python
-            new_request = request.override(system_prompt="New instructions", tool_choice="auto")
+            from langchain_core.messages import SystemMessage
+
+            new_request = request.override(
+                system_message=SystemMessage(content="New instructions")
+            )
+            ```
+
+        !!! example "Override system prompt (backward compatible)"
example "Override system prompt (backward compatible)" + + ```python + new_request = request.override(system_prompt="New instructions") ``` """ + # Handle system_prompt/system_message conversion + if "system_prompt" in overrides and "system_message" in overrides: + msg = "Cannot specify both system_prompt and system_message" + raise ValueError(msg) + + if "system_prompt" in overrides: + system_prompt = overrides["system_prompt"] + if system_prompt is None: + overrides["system_message"] = None + else: + overrides["system_message"] = SystemMessage(content=system_prompt) + # Remove system_prompt from overrides to avoid conflict + overrides = {k: v for k, v in overrides.items() if k != "system_prompt"} + return replace(self, **overrides) @@ -603,10 +698,12 @@ def __call__( class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc] - """Callable that returns a prompt string given `ModelRequest` (contains state and runtime).""" + """Callable that returns a prompt string or SystemMessage given `ModelRequest`.""" - def __call__(self, request: ModelRequest) -> str | Awaitable[str]: - """Generate a system prompt string based on the request.""" + def __call__( + self, request: ModelRequest + ) -> str | SystemMessage | Awaitable[str | SystemMessage]: + """Generate a system prompt string or SystemMessage based on the request.""" ... @@ -1326,7 +1423,8 @@ def dynamic_prompt( function that can be applied to a function. The decorated function should return: - - `str` – The system prompt to use for the model request + - `str` – The system prompt string to use for the model request + - `SystemMessage` – A complete system message to use for the model request Examples: Basic usage with dynamic content: @@ -1369,7 +1467,10 @@ async def async_wrapped( handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: prompt = await func(request) # type: ignore[misc] - request = request.override(system_prompt=prompt) + if isinstance(prompt, SystemMessage): + request = request.override(system_message=prompt) + else: + request = request.override(system_message=SystemMessage(content=prompt)) return await handler(request) middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware")) @@ -1389,8 +1490,11 @@ def wrapped( request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: - prompt = cast("str", func(request)) - request = request.override(system_prompt=prompt) + prompt = func(request) + if isinstance(prompt, SystemMessage): + request = request.override(system_message=prompt) + else: + request = request.override(system_message=SystemMessage(content=prompt)) return handler(request) async def async_wrapped_from_sync( @@ -1399,8 +1503,11 @@ async def async_wrapped_from_sync( handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: # Delegate to sync function - prompt = cast("str", func(request)) - request = request.override(system_prompt=prompt) + prompt = func(request) + if isinstance(prompt, SystemMessage): + request = request.override(system_message=prompt) + else: + request = request.override(system_message=SystemMessage(content=prompt)) return await handler(request) middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware")) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_system_message_support.py b/libs/langchain_v1/tests/unit_tests/agents/test_system_message_support.py new file mode 100644 index 0000000000000..d56696b370b4f --- /dev/null +++ 
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_system_message_support.py b/libs/langchain_v1/tests/unit_tests/agents/test_system_message_support.py
new file mode 100644
index 0000000000000..d56696b370b4f
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_system_message_support.py
@@ -0,0 +1,263 @@
+"""Unit tests for SystemMessage support in create_agent and ModelRequest."""
+
+import pytest
+from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+
+from langchain.agents.middleware.types import ModelRequest
+
+
+class TestModelRequestSystemMessage:
+    """Test ModelRequest with system_message field."""
+
+    def test_create_with_system_message(self) -> None:
+        """Test creating ModelRequest with SystemMessage."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        system_msg = SystemMessage(content="You are a helpful assistant")
+
+        request = ModelRequest(
+            model=model,
+            system_message=system_msg,
+            messages=[HumanMessage("Hi")],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        assert request.system_message == system_msg
+        assert request.system_prompt == "You are a helpful assistant"
+
+    def test_create_with_none_system_message(self) -> None:
+        """Test creating ModelRequest with None system_message."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        request = ModelRequest(
+            model=model,
+            system_message=None,
+            messages=[HumanMessage("Hi")],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        assert request.system_message is None
+        assert request.system_prompt is None
+
+    def test_system_prompt_property_with_string_content(self) -> None:
+        """Test system_prompt property returns content from SystemMessage."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        system_msg = SystemMessage(content="Test prompt")
+
+        request = ModelRequest(
+            model=model,
+            system_message=system_msg,
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        assert request.system_prompt == "Test prompt"
+
+    def test_system_prompt_property_with_list_content(self) -> None:
+        """Test system_prompt property handles list content."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        system_msg = SystemMessage(content=["Part 1", "Part 2"])
+
+        request = ModelRequest(
+            model=model,
+            system_message=system_msg,
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        # Should convert list content to string
+        assert request.system_prompt is not None
+        assert "Part 1" in request.system_prompt
+
+    def test_override_with_system_message(self) -> None:
+        """Test override() with system_message parameter."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        original_msg = SystemMessage(content="Original")
+        new_msg = SystemMessage(content="New")
+
+        original_request = ModelRequest(
+            model=model,
+            system_message=original_msg,
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        new_request = original_request.override(system_message=new_msg)
+
+        assert new_request.system_message == new_msg
+        assert new_request.system_prompt == "New"
+        assert original_request.system_message == original_msg
+        assert original_request.system_prompt == "Original"
+
+    def test_override_with_system_prompt_backward_compat(self) -> None:
+        """Test override() with system_prompt parameter (backward compatibility)."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        original_request = ModelRequest(
+            model=model,
+            system_message=SystemMessage(content="Original"),
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        # Override using system_prompt (backward compat)
+        new_request = original_request.override(system_prompt="New prompt")
+
+        assert new_request.system_prompt == "New prompt"
+        assert isinstance(new_request.system_message, SystemMessage)
+        assert new_request.system_message.content == "New prompt"
+
+    def test_override_system_prompt_to_none(self) -> None:
+        """Test override() setting system_prompt to None."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        original_request = ModelRequest(
+            model=model,
+            system_message=SystemMessage(content="Original"),
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        new_request = original_request.override(system_prompt=None)
+
+        assert new_request.system_message is None
+        assert new_request.system_prompt is None
+
+    def test_setattr_system_prompt_deprecated(self) -> None:
+        """Test that setting system_prompt via setattr raises deprecation warning."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        request = ModelRequest(
+            model=model,
+            system_message=None,
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        with pytest.warns(DeprecationWarning, match="system_prompt is deprecated"):
+            request.system_prompt = "New prompt"
+
+        # Should still work but convert to SystemMessage
+        assert isinstance(request.system_message, SystemMessage)
+        assert request.system_message.content == "New prompt"
+
+    def test_multiple_overrides_with_system_message(self) -> None:
+        """Test chaining overrides with system_message."""
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        original_request = ModelRequest(
+            model=model,
+            system_message=SystemMessage(content="Prompt 1"),
+            messages=[],
+            tool_choice=None,
+            tools=[],
+            response_format=None,
+            state={},
+            runtime=None,
+        )
+
+        final_request = (
+            original_request.override(system_message=SystemMessage(content="Prompt 2"))
+            .override(tool_choice="auto")
+            .override(system_message=SystemMessage(content="Prompt 3"))
+        )
+
+        assert final_request.system_prompt == "Prompt 3"
+        assert final_request.tool_choice == "auto"
+        assert original_request.system_prompt == "Prompt 1"
+
+
+class TestCreateAgentSystemMessage:
+    """Test create_agent with SystemMessage support."""
+
+    def test_create_agent_with_string_system_prompt(self) -> None:
+        """Test create_agent accepts string system_prompt (backward compat)."""
+        from langchain.agents import create_agent
+
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        # Should not raise - backward compatibility
+        agent = create_agent(
+            model=model,
+            system_prompt="You are a helpful assistant",
+        )
+
+        assert agent is not None
+
+    def test_create_agent_with_system_message(self) -> None:
+        """Test create_agent accepts SystemMessage for system_prompt."""
+        from langchain.agents import create_agent
+
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        system_msg = SystemMessage(content="You are a helpful assistant")
+
+        # Should not raise
+        agent = create_agent(
+            model=model,
+            system_prompt=system_msg,
+        )
+
+        assert agent is not None
+
+    def test_create_agent_with_none_system_prompt(self) -> None:
+        """Test create_agent with None system_prompt."""
+        from langchain.agents import create_agent
+
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+
+        # Should not raise
+        agent = create_agent(
+            model=model,
+            system_prompt=None,
+        )
+
+        assert agent is not None
+
+    def test_create_agent_system_message_with_metadata(self) -> None:
+        """Test create_agent with SystemMessage containing metadata."""
+        from langchain.agents import create_agent
+
+        model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
+        system_msg = SystemMessage(
+            content="You are a helpful assistant", additional_kwargs={"role": "system_admin"}
+        )
+
+        # Should not raise and preserve metadata
+        agent = create_agent(
+            model=model,
+            system_prompt=system_msg,
+        )
+
+        assert agent is not None
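An end-to-end sketch tying the pieces together. The tests above only assert construction; the invoke shape below follows the standard LangGraph agent contract and the assertion on the fake model's canned reply is illustrative, not part of this PR's test suite.

```python
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain.agents import create_agent

# The fake model replays a canned AIMessage, as in the tests above.
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(
    model=model,
    system_prompt=SystemMessage(content="You are a helpful assistant"),
)

# The SystemMessage is prepended to the message list when the model is called.
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert result["messages"][-1].content == "Hello"
```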