1 | 1 | from __future__ import annotations |
2 | 2 |
| 3 | +import abc |
| 4 | +import json |
3 | 5 | import logging |
| 6 | +import random |
| 7 | +import string |
4 | 8 | import sys |
| 9 | +from json import JSONDecodeError |
| 10 | +from pathlib import Path |
5 | 11 |
6 | | -from pydantic import BaseModel |
7 | | -from pydantic_ai import Agent |
8 | | -from pydantic_ai.models.anthropic import AnthropicModel |
9 | | -from typing_extensions import Any, Dict, Optional, Union |
| 12 | +import chevron |
| 13 | +from openai.types.chat import ChatCompletionMessageParam |
| 14 | +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam |
10 | 15 |
11 | | -from patchwork.common.client.llm.utils import example_json_to_base_model |
12 | | -from patchwork.common.tools import Tool |
13 | | -from patchwork.common.utils.utils import mustache_render |
| 16 | +from patchwork.common.client.llm.protocol import LlmClient |
| 17 | +from patchwork.common.tools import CodeEditTool, Tool |
| 18 | +from patchwork.common.tools.agentic_tools import EndTool |
14 | 19 |
15 | | -_COMPLETION_FLAG_ATTRIBUTE = "is_task_completed" |
16 | | -_MESSAGE_ATTRIBUTE = "message" |
17 | 20 |
| 21 | +class Role(abc.ABC): |
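| | + # Base class for a chat participant: owns an LLM client, a set of callable tools, and the running message history. |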
| 22 | + def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool]): |
| 23 | + self.llm_client = llm_client |
| 24 | + self.tool_set = tool_set |
| 25 | + self.history: list[ChatCompletionMessageParam] = [] |
18 | 26 |
19 | | -class AgentConfig(BaseModel): |
20 | | - class Config: |
21 | | - arbitrary_types_allowed = True |
| 27 | + def generate_reply(self, message: str) -> str: |
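| | + # One conversational turn: record the user message, call the LLM, run any requested tools, and return the final text reply. |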
| 28 | + self.history.append(dict(role="user", content=message)) |
| 29 | + input_kwargs = dict( |
| 30 | + messages=self.history, |
| 31 | + model="claude-3-5-sonnet-latest", |
| 32 | + tools=self.__get_tools_spec(), |
| 33 | + max_tokens=8096, |
| 34 | + ) |
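| | + # is_prompt_supported is expected to return a negative value once the accumulated prompt no longer fits the context window. |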
| 35 | + is_prompt_safe = self.llm_client.is_prompt_supported(**input_kwargs) |
| 36 | + if is_prompt_safe < 0: |
| 37 | + raise ValueError("The subsequent prompt is not supported because it is too large.") |
| 38 | + response = self.llm_client.chat_completion(**input_kwargs) |
| 39 | + choices = response.choices or [] |
| 40 | + |
| 41 | + message_content = "" |
| 42 | + for choice in choices: |
| 43 | + new_message = choice.message.to_dict() |
| 44 | + self.history.append(new_message) |
| 45 | + if new_message.get("tool_calls") is not None: |
| 46 | + self.history.extend(self.__execute_tools(new_message)) |
| 47 | + else: |
| 48 | + message_content = new_message["content"] |
| 49 | + |
| 50 | + return message_content |
| 51 | + |
| 52 | + def __execute_tools(self, last_message: ChatCompletionMessageParam) -> list[ChatCompletionMessageParam]: |
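| | + # Execute every tool call requested by the model and collect the role="tool" result messages to append to the history. |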
| 53 | + rv = [] |
| 54 | + for tool_call in last_message.get("tool_calls", []): |
| 55 | + tool_name_to_use = tool_call.get("function", {}).get("name") |
| 56 | + tool_to_use = self.tool_set.get(tool_name_to_use, None) |
| 57 | + if tool_to_use is None: |
| 58 | + logging.info("LLM just used a non-existent tool!") |
| 59 | + continue |
22 | 60 |
23 | | - name: str |
24 | | - tool_set: Dict[str, Tool] |
25 | | - system_prompt: str = "" |
26 | | - example_json: Union[ |
27 | | - str, Dict[str, Any] |
28 | | - ] = f'{{"{_MESSAGE_ATTRIBUTE}":"message", "{_COMPLETION_FLAG_ATTRIBUTE}": false}}' |
| 61 | + logging.info(f"Running tool: {tool_name_to_use}") |
| 62 | + try: |
| 63 | + tool_arguments = tool_call.get("function", {}).get("arguments", "{}") |
| 64 | + tool_kwargs = json.loads(tool_arguments) |
| 65 | + tool_output = tool_to_use.execute(**tool_kwargs) |
| 66 | + except JSONDecodeError: |
| 67 | + tool_output = "Arguments must be passed as a valid JSON object" |
| 68 | + |
| 69 | + rv.append({"tool_call_id": tool_call.get("id", ""), "role": "tool", "content": tool_output}) |
| 70 | + |
| 71 | + return rv |
| 72 | + |
| 73 | + def __get_tools_spec(self) -> list[ChatCompletionToolParam]: |
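| | + # Convert the tool set into OpenAI-style function-calling specs, using each tool's json_schema as the function definition. |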
| 74 | + return [ |
| 75 | + dict( |
| 76 | + type="function", |
| 77 | + function={"name": k, **v.json_schema}, |
| 78 | + ) |
| 79 | + for k, v in self.tool_set.items() |
| 80 | + ] |
| 81 | + |
| 82 | + |
| 83 | +class UserProxy(Role): |
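| | + # Stands in for the human user; unless reply_message is None, the canned reply is returned verbatim instead of querying the LLM. |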
| 84 | + def __init__( |
| 85 | + self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prompt: str | None = None, reply_message: str | None = "" |
| 86 | + ): |
| 87 | + super().__init__(llm_client, tool_set) |
| 88 | + if system_prompt is not None: |
| 89 | + self.history.append(dict(role="system", content=system_prompt)) |
| 90 | + |
| 91 | + self.__reply_message = reply_message |
| 92 | + |
| 93 | + def generate_reply(self, message: str) -> str: |
| 94 | + if self.__reply_message is not None: |
| 95 | + self.history.append(dict(role="user", content=message)) |
| 96 | + self.history.append(dict(role="assistant", content=self.__reply_message)) |
| 97 | + return self.__reply_message |
| 98 | + else: |
| 99 | + return super().generate_reply(message) |
| 100 | + |
| 101 | + |
| 102 | +class Assistant(Role): |
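| | + # The tool-using worker role; its only specialisation over Role is an optional system prompt seeded into the history. |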
| 103 | + def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prompt: str | None = None): |
| 104 | + super().__init__(llm_client, tool_set) |
| 105 | + if system_prompt is not None: |
| 106 | + self.history.append(dict(role="system", content=system_prompt)) |
29 | 107 |
30 | 108 |
31 | 109 | class AgenticStrategy: |
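| | + # Orchestrates a two-role (assistant / user-proxy) conversation over a shared tool set; an EndTool is always injected so the model can terminate the session. |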
32 | 110 | def __init__( |
33 | 111 | self, |
34 | | - api_key: str, |
| 112 | + llm_client: LlmClient, |
| 113 | + tool_set: dict[str, Tool], |
35 | 114 | template_data: dict[str, str], |
36 | 115 | system_prompt_template: str, |
37 | 116 | user_prompt_template: str, |
38 | | - agent_configs: list[AgentConfig], |
39 | | - example_json: Union[str, dict[str, Any]] = '{"output":"output text"}', |
40 | | - limit: Optional[int] = None, |
| 117 | + *args, |
| 118 | + **kwargs, |
41 | 119 | ): |
42 | | - self.__limit = limit |
| 120 | + self.tool_set = dict(end=EndTool(), **tool_set) |
43 | 121 | self.__template_data = template_data |
44 | 122 | self.__user_prompt_template = user_prompt_template |
45 | | - model = AnthropicModel("claude-3-5-sonnet-latest", api_key=api_key) |
46 | | - self.__summariser = Agent( |
47 | | - model, |
48 | | - system_prompt=mustache_render(system_prompt_template, self.__template_data), |
49 | | - result_type=example_json_to_base_model(example_json), |
50 | | - model_settings=dict(parallel_tool_calls=False), |
| 123 | + self.__assistant_role = Assistant(llm_client, self.tool_set, self.__render_prompt(system_prompt_template)) |
| 124 | + self.__user_role = UserProxy(llm_client, dict()) |
| 125 | + |
| 126 | + def __render_prompt(self, prompt_template: str) -> str: |
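| | + # Patch chevron's module-level _html_escape so mustache rendering leaves prompt text unescaped, and use a random partials extension so no partial files are ever resolved from disk. |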
| 127 | + chevron.render.__globals__["_html_escape"] = lambda x: x |
| 128 | + return chevron.render( |
| 129 | + template=prompt_template, |
| 130 | + data=self.__template_data, |
| 131 | + partials_path=None, |
| 132 | + partials_ext="".join(random.choices(string.ascii_uppercase + string.digits, k=32)), |
| 133 | + partials_dict=dict(), |
51 | 134 | ) |
52 | | - self.__agents = [] |
53 | | - for agent_config in agent_configs: |
54 | | - tools = [] |
55 | | - for tool in agent_config.tool_set.values(): |
56 | | - tools.append(tool.to_pydantic_ai_function_tool()) |
57 | | - agent = Agent( |
58 | | - model, |
59 | | - name=agent_config.name, |
60 | | - system_prompt=mustache_render(agent_config.system_prompt, self.__template_data), |
61 | | - tools=tools, |
62 | | - result_type=example_json_to_base_model(agent_config.example_json), |
63 | | - model_settings=dict(parallel_tool_calls=False), |
64 | | - ) |
65 | 135 |
66 | | - self.__agents.append(agent) |
| 136 | + def __get_initial_prompt(self) -> list[ChatCompletionMessageParam]: |
| 137 | + return [dict(role="user", content=self.__render_prompt(self.__user_prompt_template))] |
67 | 138 |
68 | | - def execute(self, limit: Optional[int] = None) -> dict: |
69 | | - agents_result = dict() |
| 139 | + def __is_session_completed(self) -> bool: |
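| | + # The session is complete once EndTool has run; look backwards through the history for its sentinel output message. |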
| 140 | + for message in reversed(self.__assistant_role.history): |
| 141 | + if message.get("role") != "tool": |
| 142 | + continue |
| 143 | + if message.get("content") == EndTool.MESSAGE: |
| 144 | + return True |
| 145 | + |
| 146 | + return False |
| 147 | + |
| 148 | + def execute(self, limit: int | None = None) -> None: |
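| | + # Alternate assistant and user-proxy turns until EndTool signals completion or the turn limit is exhausted; run_count tracks the current turn and is reset afterwards. |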
| 149 | + message = self.__render_prompt(self.__user_prompt_template) |
70 | 150 | try: |
71 | | - for index, agent in enumerate(self.__agents): |
72 | | - user_message = mustache_render(self.__user_prompt_template, self.__template_data) |
73 | | - message_history = None |
74 | | - agent_output = None |
75 | | - for i in range(limit or self.__limit or sys.maxsize): |
76 | | - agent_output = agent.run_sync(user_message, message_history=message_history) |
77 | | - message_history = agent_output.all_messages() |
78 | | - if getattr(agent_output.data, _COMPLETION_FLAG_ATTRIBUTE, False): |
79 | | - break |
80 | | - user_message = "Please continue" |
81 | | - agents_result[index] = agent_output |
| 151 | + for i in range(limit or sys.maxsize): |
| 152 | + self.run_count = i + 1 |
| 153 | + for role in [self.__assistant_role, self.__user_role]: |
| 154 | + message = role.generate_reply(message) |
| 155 | + if self.__is_session_completed(): |
| 156 | + return |
82 | 157 | except Exception as e: |
83 | 158 | logging.error(e) |
| 159 | + finally: |
| 160 | + self.run_count = 0 |
84 | 161 |
85 | | - if len(agents_result) == 0: |
86 | | - return dict() |
87 | | - |
88 | | - if len(agents_result) == 1: |
89 | | - final_result = self.__summariser.run_sync( |
90 | | - "From the actions taken by the assistant. Please give me the result.", |
91 | | - message_history=next(v for _, v in agents_result.items()).all_messages(), |
92 | | - ) |
93 | | - else: |
94 | | - agent_summaries = [] |
95 | | - for agent_result in agents_result.values(): |
96 | | - agent_summary_result = self.__summariser.run_sync( |
97 | | - "From the actions taken by the assistant. Please give me the result.", |
98 | | - message_history=agent_result.all_messages(), |
99 | | - ) |
100 | | - agent_summary = getattr(agent_summary_result.data, _MESSAGE_ATTRIBUTE, None) |
101 | | - if agent_summary is None: |
102 | | - continue |
103 | | - |
104 | | - agent_summaries.append(agent_summary) |
105 | | - agent_summary_list = "\n* " + "\n* ".join(agent_summaries) |
106 | | - final_result = self.__summariser.run_sync( |
107 | | - "Please give me the result from the following summary of what the assistants have done." |
108 | | - + agent_summary_list, |
109 | | - ) |
| 162 | + @property |
| 163 | + def history(self): |
| 164 | + return self.__user_role.history |
110 | 165 |
111 | | - return final_result.data.dict() |
| 166 | + @property |
| 167 | + def tool_records(self): |
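| | + # Report files modified by a CodeEditTool, as paths relative to the current working directory; empty if no such tool is present. |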
| 168 | + for tool in self.tool_set.values(): |
| 169 | + if isinstance(tool, CodeEditTool): |
| 170 | + cwd = Path.cwd() |
| 171 | + modified_files = [file_path.relative_to(cwd) for file_path in tool.tool_records["modified_files"]] |
| 172 | + return [dict(path=str(file)) for file in modified_files] |
| 173 | + return [] |