11from __future__ import annotations
22
3- import abc
4- import json
53import logging
6- import random
7- import string
84import sys
9- from json import JSONDecodeError
10- from pathlib import Path
11- from typing import Union , Any
125
13- import chevron
14- from openai .types .chat import ChatCompletionMessageParam
15- from openai .types .chat .chat_completion_tool_param import ChatCompletionToolParam
6+ from pydantic import BaseModel
167from pydantic_ai import Agent
178from pydantic_ai .models .anthropic import AnthropicModel
18- from pydantic import BaseModel
19-
20- from patchwork .common .client .llm .protocol import LlmClient
21- from patchwork .common .client .llm .utils import example_string_to_base_model , example_json_to_base_model
22- from patchwork .common .tools import CodeEditTool , Tool
23- from patchwork .common .tools .agentic_tools import EndTool
24-
25-
26- class Role (abc .ABC ):
27- def __init__ (self , llm_client : LlmClient , tool_set : dict [str , Tool ]):
28- self .llm_client = llm_client
29- self .tool_set = tool_set
30- self .history : list [ChatCompletionMessageParam ] = []
31-
32- def generate_reply (self , message : str ) -> str :
33- self .history .append (dict (role = "user" , content = message ))
34- input_kwargs = dict (
35- messages = self .history ,
36- model = "claude-3-5-sonnet-latest" ,
37- tools = self .__get_tools_spec (),
38- max_tokens = 8096 ,
39- )
40- is_prompt_safe = self .llm_client .is_prompt_supported (** input_kwargs )
41- if is_prompt_safe < 0 :
42- raise ValueError ("The subsequent prompt is not supported, due to large size." )
43- response = self .llm_client .chat_completion (** input_kwargs )
44- choices = response .choices or []
45-
46- message_content = ""
47- for choice in choices :
48- new_message = choice .message .to_dict ()
49- self .history .append (new_message )
50- if new_message .get ("tool_calls" ) is not None :
51- self .history .extend (self .__execute_tools (new_message ))
52- else :
53- message_content = new_message ["content" ]
54-
55- return message_content
56-
57- def __execute_tools (self , last_message : ChatCompletionMessageParam ) -> list [ChatCompletionMessageParam ]:
58- rv = []
59- for tool_call in last_message .get ("tool_calls" , []):
60- tool_name_to_use = tool_call .get ("function" , {}).get ("name" )
61- tool_to_use = self .tool_set .get (tool_name_to_use , None )
62- if tool_to_use is None :
63- logging .info ("LLM just used an non-existent tool!" )
64- continue
65-
66- logging .info (f"Running tool: { tool_name_to_use } " )
67- try :
68- tool_arguments = tool_call .get ("function" , {}).get ("arguments" , "{}" )
69- tool_kwargs = json .loads (tool_arguments )
70- tool_output = tool_to_use .execute (** tool_kwargs )
71- except JSONDecodeError :
72- tool_output = "Arguments must be passed through a valid JSON object"
73-
74- rv .append ({"tool_call_id" : tool_call .get ("id" , "" ), "role" : "tool" , "content" : tool_output })
75-
76- return rv
77-
78- def __get_tools_spec (self ) -> list [ChatCompletionToolParam ]:
79- return [
80- dict (
81- type = "function" ,
82- function = {"name" : k , ** v .json_schema },
83- )
84- for k , v in self .tool_set .items ()
85- ]
86-
87-
88- class UserProxy (Role ):
89- def __init__ (
90- self , llm_client : LlmClient , tool_set : dict [str , Tool ], system_prompt : str = None , reply_message : str = ""
91- ):
92- super ().__init__ (llm_client , tool_set )
93- if system_prompt is not None :
94- self .history .append (dict (role = "system" , content = system_prompt ))
95-
96- self .__reply_message = reply_message
97-
98- def generate_reply (self , message : str ) -> str :
99- if self .__reply_message is not None :
100- self .history .append (dict (role = "user" , content = message ))
101- self .history .append (dict (role = "assistant" , content = self .__reply_message ))
102- return self .__reply_message
103- else :
104- return super ().generate_reply (message )
105-
106-
107- class Assistant (Role ):
108- def __init__ (self , llm_client : LlmClient , tool_set : dict [str , Tool ], system_prompt : str = None ):
109- super ().__init__ (llm_client , tool_set )
110- if system_prompt is not None :
111- self .history .append (dict (role = "system" , content = system_prompt ))
9+ from typing_extensions import Any , Dict , Optional , Union
11210
11+ from patchwork .common .client .llm .utils import example_json_to_base_model
12+ from patchwork .common .tools import Tool
13+ from patchwork .common .utils .utils import mustache_render
11314
15+ _COMPLETION_FLAG_ATTRIBUTE = "is_task_completed"
16+ _MESSAGE_ATTRIBUTE = "message"
11417
11518
11619class AgentConfig (BaseModel ):
20+ class Config :
21+ arbitrary_types_allowed = True
22+
11723 name : str
118- tool_set : dict [str , Tool ]
119- system_prompt : str = ''
24+ tool_set : Dict [str , Tool ]
25+ system_prompt : str = ""
26+ example_json : Union [
27+ str , Dict [str , Any ]
28+ ] = f'{{"{ _MESSAGE_ATTRIBUTE } ":"message", "{ _COMPLETION_FLAG_ATTRIBUTE } ": false}}'
12029
12130
12231class AgenticStrategy :
@@ -128,73 +37,75 @@ def __init__(
12837 user_prompt_template : str ,
12938 agent_configs : list [AgentConfig ],
13039 example_json : Union [str , dict [str , Any ]] = '{"output":"output text"}' ,
131- * args ,
132- ** kwargs ,
40+ limit : Optional [int ] = None ,
13341 ):
42+ self .__limit = limit
13443 self .__template_data = template_data
13544 self .__user_prompt_template = user_prompt_template
13645 model = AnthropicModel ("claude-3-5-sonnet-latest" , api_key = api_key )
137- self .__user_role = Agent (
46+ self .__summariser = Agent (
13847 model ,
139- system_prompt = self . __render_prompt (system_prompt_template ),
48+ system_prompt = mustache_render (system_prompt_template , self . __template_data ),
14049 result_type = example_json_to_base_model (example_json ),
50+ model_settings = dict (parallel_tool_calls = False ),
14151 )
142- self .__assistants = []
143- for assistant_config in agent_configs :
52+ self .__agents = []
53+ for agent_config in agent_configs :
14454 tools = []
145- for tool in assistant_config .tool_set .values ():
55+ for tool in agent_config .tool_set .values ():
14656 tools .append (tool .to_pydantic_ai_function_tool ())
147- assistant = Agent (
148- "claude-3-5-sonnet-latest" ,
149- system_prompt = self .__render_prompt (assistant_config .system_prompt ),
150- tools = tools
57+ agent = Agent (
58+ model ,
59+ name = agent_config .name ,
60+ system_prompt = mustache_render (agent_config .system_prompt , self .__template_data ),
61+ tools = tools ,
62+ result_type = example_json_to_base_model (agent_config .example_json ),
63+ model_settings = dict (parallel_tool_calls = False ),
15164 )
15265
153- self .__assistants .append (assistant )
66+ self .__agents .append (agent )
15467
155- def __render_prompt (self , prompt_template : str ) -> str :
156- chevron .render .__globals__ ["_html_escape" ] = lambda x : x
157- return chevron .render (
158- template = prompt_template ,
159- data = self .__template_data ,
160- partials_path = None ,
161- partials_ext = "" .join (random .choices (string .ascii_uppercase + string .digits , k = 32 )),
162- partials_dict = dict (),
163- )
164-
165- def __is_session_completed (self ) -> bool :
166- for message in reversed (self .__assistant_role .history ):
167- if message .get ("tool" ) is not None :
168- continue
169- if message .get ("content" ) == EndTool .MESSAGE :
170- return True
171-
172- return False
173-
174- def execute (self , limit : int | None = None ) -> None :
175- message = self .__render_prompt (self .__user_prompt_template )
68+ def execute (self , limit : Optional [int ] = None ) -> dict :
69+ agents_result = dict ()
17670 try :
177- for i in range (limit or self .__limit or sys .maxsize ):
178- self .__user_role .run_sync (self .__user_prompt_template )
179- self .run_count = i + 1
180- for role in [* self .__assistants , self .__user_role ]:
181- message = role .run_sync (message )
182- if self .__is_session_completed ():
183- break
71+ for index , agent in enumerate (self .__agents ):
72+ user_message = mustache_render (self .__user_prompt_template , self .__template_data )
73+ message_history = None
74+ agent_output = None
75+ for i in range (limit or self .__limit or sys .maxsize ):
76+ agent_output = agent .run_sync (user_message , message_history = message_history )
77+ message_history = agent_output .all_messages ()
78+ if getattr (agent_output .data , _COMPLETION_FLAG_ATTRIBUTE , False ):
79+ break
80+ user_message = "Please continue"
81+ agents_result [index ] = agent_output
18482 except Exception as e :
18583 logging .error (e )
186- finally :
187- self .run_count = 0
18884
189- @property
190- def history (self ):
191- return self .__user_role .history
85+ if len (agents_result ) == 0 :
86+ return dict ()
87+
88+ if len (agents_result ) == 1 :
89+ final_result = self .__summariser .run_sync (
90+ "From the actions taken by the assistant. Please give me the result." ,
91+ message_history = next (v for _ , v in agents_result .items ()).all_messages (),
92+ )
93+ else :
94+ agent_summaries = []
95+ for agent_result in agents_result .values ():
96+ agent_summary_result = self .__summariser .run_sync (
97+ "From the actions taken by the assistant. Please give me the result." ,
98+ message_history = agent_result .all_messages (),
99+ )
100+ agent_summary = getattr (agent_summary_result .data , _MESSAGE_ATTRIBUTE , None )
101+ if agent_summary is None :
102+ continue
103+
104+ agent_summaries .append (agent_summary )
105+ agent_summary_list = "\n * " + "\n * " .join (agent_summaries )
106+ final_result = self .__summariser .run_sync (
107+ "Please give me the result from the following summary of what the assistants have done."
108+ + agent_summary_list ,
109+ )
192110
193- @property
194- def tool_records (self ):
195- for tool in self .tool_set .values ():
196- if isinstance (tool , CodeEditTool ):
197- cwd = Path .cwd ()
198- modified_files = [file_path .relative_to (cwd ) for file_path in tool .tool_records ["modified_files" ]]
199- return [dict (path = str (file )) for file in modified_files ]
200- return []
111+ return final_result .data .dict ()
0 commit comments