11import asyncio
22import json
33
4- from langchain_core .messages import ToolMessage
4+ from langchain_core .messages import SystemMessage , ToolMessage
55from langgraph .graph import END , StateGraph
66
77from template_langgraph .agents .chat_with_tools_agent .models import AgentState
@@ -55,11 +55,14 @@ def __init__(
5555 tools = get_default_tools (),
5656 checkpointer = None ,
5757 store = None ,
58+ system_prompt : str | None = None ,
5859 ):
5960 self .llm = AzureOpenAiWrapper ().chat_model
6061 self .tools = tools
6162 self .checkpointer = checkpointer
6263 self .store = store
64+ self .system_prompt = system_prompt
65+ self ._system_message = SystemMessage (content = system_prompt ) if system_prompt else None
6366
6467 def create_graph (self ):
6568 """Create the main graph for the agent."""
@@ -100,9 +103,10 @@ def chat_with_tools(self, state: AgentState) -> AgentState:
100103 llm_with_tools = self .llm .bind_tools (
101104 tools = self .tools ,
102105 )
106+ messages = self ._prepare_messages (state )
103107 return {
104108 "messages" : [
105- llm_with_tools .invoke (state [ " messages" ] ),
109+ llm_with_tools .invoke (messages ),
106110 ]
107111 }
108112
@@ -124,5 +128,15 @@ def route_tools(
124128 return "tools"
125129 return END
126130
131+ def _prepare_messages (self , state : AgentState ):
132+ """Return a message list with the optional system prompt prefixed."""
133+ base_messages = list (state ) if isinstance (state , list ) else list (state .get ("messages" , []))
134+ if not self ._system_message :
135+ return base_messages
136+ if base_messages and isinstance (base_messages [0 ], SystemMessage ):
137+ if base_messages [0 ].content == self ._system_message .content :
138+ return base_messages
139+ return [self ._system_message , * base_messages ]
140+
127141
# Module-level graph built with the agent's default configuration
# (default tools, no checkpointer/store, no system prompt — see __init__ defaults).
graph = ChatWithToolsAgent().create_graph()