Skip to content

Commit d001147

Browse files
committed
gen: #file:agent.py の ChatWithToolsAgent のコンストラクタからシステムプロンプトを注入するインタフェースを追加して。互換性も担保した形式で実装して。
1 parent a74abda commit d001147

File tree

1 file changed

+16
-2
lines changed
  • template_langgraph/agents/chat_with_tools_agent

1 file changed

+16
-2
lines changed

template_langgraph/agents/chat_with_tools_agent/agent.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import json
33

4-
from langchain_core.messages import ToolMessage
4+
from langchain_core.messages import SystemMessage, ToolMessage
55
from langgraph.graph import END, StateGraph
66

77
from template_langgraph.agents.chat_with_tools_agent.models import AgentState
@@ -55,11 +55,14 @@ def __init__(
5555
tools=get_default_tools(),
5656
checkpointer=None,
5757
store=None,
58+
system_prompt: str | None = None,
5859
):
5960
self.llm = AzureOpenAiWrapper().chat_model
6061
self.tools = tools
6162
self.checkpointer = checkpointer
6263
self.store = store
64+
self.system_prompt = system_prompt
65+
self._system_message = SystemMessage(content=system_prompt) if system_prompt else None
6366

6467
def create_graph(self):
6568
"""Create the main graph for the agent."""
@@ -100,9 +103,10 @@ def chat_with_tools(self, state: AgentState) -> AgentState:
100103
llm_with_tools = self.llm.bind_tools(
101104
tools=self.tools,
102105
)
106+
messages = self._prepare_messages(state)
103107
return {
104108
"messages": [
105-
llm_with_tools.invoke(state["messages"]),
109+
llm_with_tools.invoke(messages),
106110
]
107111
}
108112

@@ -124,5 +128,15 @@ def route_tools(
124128
return "tools"
125129
return END
126130

131+
def _prepare_messages(self, state: AgentState):
132+
"""Return a message list with the optional system prompt prefixed."""
133+
base_messages = list(state) if isinstance(state, list) else list(state.get("messages", []))
134+
if not self._system_message:
135+
return base_messages
136+
if base_messages and isinstance(base_messages[0], SystemMessage):
137+
if base_messages[0].content == self._system_message.content:
138+
return base_messages
139+
return [self._system_message, *base_messages]
140+
127141

128142
graph = ChatWithToolsAgent().create_graph()

0 commit comments

Comments
 (0)