Skip to content

Commit fabd099

Browse files
committed
add v2 versioning
1 parent 070fce1 commit fabd099

File tree

8 files changed

+245
-108
lines changed

8 files changed

+245
-108
lines changed
Lines changed: 144 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,173 @@
11
from __future__ import annotations
22

3+
import abc
4+
import json
35
import logging
6+
import random
7+
import string
48
import sys
9+
from json import JSONDecodeError
10+
from pathlib import Path
511

6-
from pydantic import BaseModel
7-
from pydantic_ai import Agent
8-
from pydantic_ai.models.anthropic import AnthropicModel
9-
from typing_extensions import Any, Dict, Optional, Union
12+
import chevron
13+
from openai.types.chat import ChatCompletionMessageParam
14+
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
1015

11-
from patchwork.common.client.llm.utils import example_json_to_base_model
12-
from patchwork.common.tools import Tool
13-
from patchwork.common.utils.utils import mustache_render
16+
from patchwork.common.client.llm.protocol import LlmClient
17+
from patchwork.common.tools import CodeEditTool, Tool
18+
from patchwork.common.tools.agentic_tools import EndTool
1419

15-
_COMPLETION_FLAG_ATTRIBUTE = "is_task_completed"
16-
_MESSAGE_ATTRIBUTE = "message"
1720

21+
class Role(abc.ABC):
22+
def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool]):
23+
self.llm_client = llm_client
24+
self.tool_set = tool_set
25+
self.history: list[ChatCompletionMessageParam] = []
1826

19-
class AgentConfig(BaseModel):
20-
class Config:
21-
arbitrary_types_allowed = True
27+
def generate_reply(self, message: str) -> str:
28+
self.history.append(dict(role="user", content=message))
29+
input_kwargs = dict(
30+
messages=self.history,
31+
model="claude-3-5-sonnet-latest",
32+
tools=self.__get_tools_spec(),
33+
max_tokens=8096,
34+
)
35+
is_prompt_safe = self.llm_client.is_prompt_supported(**input_kwargs)
36+
if is_prompt_safe < 0:
37+
raise ValueError("The subsequent prompt is not supported, due to large size.")
38+
response = self.llm_client.chat_completion(**input_kwargs)
39+
choices = response.choices or []
40+
41+
message_content = ""
42+
for choice in choices:
43+
new_message = choice.message.to_dict()
44+
self.history.append(new_message)
45+
if new_message.get("tool_calls") is not None:
46+
self.history.extend(self.__execute_tools(new_message))
47+
else:
48+
message_content = new_message["content"]
49+
50+
return message_content
51+
52+
def __execute_tools(self, last_message: ChatCompletionMessageParam) -> list[ChatCompletionMessageParam]:
53+
rv = []
54+
for tool_call in last_message.get("tool_calls", []):
55+
tool_name_to_use = tool_call.get("function", {}).get("name")
56+
tool_to_use = self.tool_set.get(tool_name_to_use, None)
57+
if tool_to_use is None:
58+
logging.info("LLM just used an non-existent tool!")
59+
continue
2260

23-
name: str
24-
tool_set: Dict[str, Tool]
25-
system_prompt: str = ""
26-
example_json: Union[
27-
str, Dict[str, Any]
28-
] = f'{{"{_MESSAGE_ATTRIBUTE}":"message", "{_COMPLETION_FLAG_ATTRIBUTE}": false}}'
61+
logging.info(f"Running tool: {tool_name_to_use}")
62+
try:
63+
tool_arguments = tool_call.get("function", {}).get("arguments", "{}")
64+
tool_kwargs = json.loads(tool_arguments)
65+
tool_output = tool_to_use.execute(**tool_kwargs)
66+
except JSONDecodeError:
67+
tool_output = "Arguments must be passed through a valid JSON object"
68+
69+
rv.append({"tool_call_id": tool_call.get("id", ""), "role": "tool", "content": tool_output})
70+
71+
return rv
72+
73+
def __get_tools_spec(self) -> list[ChatCompletionToolParam]:
74+
return [
75+
dict(
76+
type="function",
77+
function={"name": k, **v.json_schema},
78+
)
79+
for k, v in self.tool_set.items()
80+
]
81+
82+
83+
class UserProxy(Role):
84+
def __init__(
85+
self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prompt: str = None, reply_message: str = ""
86+
):
87+
super().__init__(llm_client, tool_set)
88+
if system_prompt is not None:
89+
self.history.append(dict(role="system", content=system_prompt))
90+
91+
self.__reply_message = reply_message
92+
93+
def generate_reply(self, message: str) -> str:
94+
if self.__reply_message is not None:
95+
self.history.append(dict(role="user", content=message))
96+
self.history.append(dict(role="assistant", content=self.__reply_message))
97+
return self.__reply_message
98+
else:
99+
return super().generate_reply(message)
100+
101+
102+
class Assistant(Role):
103+
def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prompt: str = None):
104+
super().__init__(llm_client, tool_set)
105+
if system_prompt is not None:
106+
self.history.append(dict(role="system", content=system_prompt))
29107

30108

31109
class AgenticStrategy:
32110
def __init__(
33111
self,
34-
api_key: str,
112+
llm_client: LlmClient,
113+
tool_set: dict[str, Tool],
35114
template_data: dict[str, str],
36115
system_prompt_template: str,
37116
user_prompt_template: str,
38-
agent_configs: list[AgentConfig],
39-
example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
40-
limit: Optional[int] = None,
117+
*args,
118+
**kwargs,
41119
):
42-
self.__limit = limit
120+
self.tool_set = dict(end=EndTool(), **tool_set)
43121
self.__template_data = template_data
44122
self.__user_prompt_template = user_prompt_template
45-
model = AnthropicModel("claude-3-5-sonnet-latest", api_key=api_key)
46-
self.__summariser = Agent(
47-
model,
48-
system_prompt=mustache_render(system_prompt_template, self.__template_data),
49-
result_type=example_json_to_base_model(example_json),
50-
model_settings=dict(parallel_tool_calls=False),
123+
self.__assistant_role = Assistant(llm_client, self.tool_set, self.__render_prompt(system_prompt_template))
124+
self.__user_role = UserProxy(llm_client, dict())
125+
126+
def __render_prompt(self, prompt_template: str) -> str:
127+
chevron.render.__globals__["_html_escape"] = lambda x: x
128+
return chevron.render(
129+
template=prompt_template,
130+
data=self.__template_data,
131+
partials_path=None,
132+
partials_ext="".join(random.choices(string.ascii_uppercase + string.digits, k=32)),
133+
partials_dict=dict(),
51134
)
52-
self.__agents = []
53-
for agent_config in agent_configs:
54-
tools = []
55-
for tool in agent_config.tool_set.values():
56-
tools.append(tool.to_pydantic_ai_function_tool())
57-
agent = Agent(
58-
model,
59-
name=agent_config.name,
60-
system_prompt=mustache_render(agent_config.system_prompt, self.__template_data),
61-
tools=tools,
62-
result_type=example_json_to_base_model(agent_config.example_json),
63-
model_settings=dict(parallel_tool_calls=False),
64-
)
65135

66-
self.__agents.append(agent)
136+
def __get_initial_prompt(self) -> list[ChatCompletionMessageParam]:
137+
return [dict(role="user", content=self.__render_prompt(self.__user_prompt_template))]
67138

68-
def execute(self, limit: Optional[int] = None) -> dict:
69-
agents_result = dict()
139+
def __is_session_completed(self) -> bool:
140+
for message in reversed(self.__assistant_role.history):
141+
if message.get("tool") is not None:
142+
continue
143+
if message.get("content") == EndTool.MESSAGE:
144+
return True
145+
146+
return False
147+
148+
def execute(self, limit: int | None = None) -> None:
149+
message = self.__render_prompt(self.__user_prompt_template)
70150
try:
71-
for index, agent in enumerate(self.__agents):
72-
user_message = mustache_render(self.__user_prompt_template, self.__template_data)
73-
message_history = None
74-
agent_output = None
75-
for i in range(limit or self.__limit or sys.maxsize):
76-
agent_output = agent.run_sync(user_message, message_history=message_history)
77-
message_history = agent_output.all_messages()
78-
if getattr(agent_output.data, _COMPLETION_FLAG_ATTRIBUTE, False):
79-
break
80-
user_message = "Please continue"
81-
agents_result[index] = agent_output
151+
for i in range(limit or self.__limit or sys.maxsize):
152+
self.run_count = i + 1
153+
for role in [self.__assistant_role, self.__user_role]:
154+
message = role.generate_reply(message)
155+
if self.__is_session_completed():
156+
break
82157
except Exception as e:
83158
logging.error(e)
159+
finally:
160+
self.run_count = 0
84161

85-
if len(agents_result) == 0:
86-
return dict()
87-
88-
if len(agents_result) == 1:
89-
final_result = self.__summariser.run_sync(
90-
"From the actions taken by the assistant. Please give me the result.",
91-
message_history=next(v for _, v in agents_result.items()).all_messages(),
92-
)
93-
else:
94-
agent_summaries = []
95-
for agent_result in agents_result.values():
96-
agent_summary_result = self.__summariser.run_sync(
97-
"From the actions taken by the assistant. Please give me the result.",
98-
message_history=agent_result.all_messages(),
99-
)
100-
agent_summary = getattr(agent_summary_result.data, _MESSAGE_ATTRIBUTE, None)
101-
if agent_summary is None:
102-
continue
103-
104-
agent_summaries.append(agent_summary)
105-
agent_summary_list = "\n* " + "\n* ".join(agent_summaries)
106-
final_result = self.__summariser.run_sync(
107-
"Please give me the result from the following summary of what the assistants have done."
108-
+ agent_summary_list,
109-
)
162+
@property
163+
def history(self):
164+
return self.__user_role.history
110165

111-
return final_result.data.dict()
166+
@property
167+
def tool_records(self):
168+
for tool in self.tool_set.values():
169+
if isinstance(tool, CodeEditTool):
170+
cwd = Path.cwd()
171+
modified_files = [file_path.relative_to(cwd) for file_path in tool.tool_records["modified_files"]]
172+
return [dict(path=str(file)) for file in modified_files]
173+
return []

patchwork/patchflows/LogAnalysis/LogAnalysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from patchwork.common.utils.step_typing import validate_steps_with_inputs
88
from patchwork.logger import logger
99
from patchwork.step import Step
10-
from patchwork.steps import AgenticLLM
10+
from patchwork.steps import AgenticLLMV2
1111

1212
_DEFAULT_INPUT_FILE = Path(__file__).parent / "defaults.yml"
1313

@@ -16,7 +16,7 @@ class LogAnalysis(Step):
1616
def __init__(self, inputs: dict):
1717
PatchflowProgressBar(self).register_steps(
1818
# CallSQL,
19-
AgenticLLM,
19+
AgenticLLMV2,
2020
)
2121
final_inputs = yaml.safe_load(_DEFAULT_INPUT_FILE.read_text()) or dict()
2222
final_inputs.update(inputs)
@@ -37,7 +37,7 @@ def run(self) -> dict:
3737
sentry_filename = "sentry_issues.json"
3838
for i in range(self.inputs.get("analysis_limit") or 5):
3939
# for i in range(self.inputs.get("log_finding_limit") or sys.maxsize):
40-
logs_detection_output = AgenticLLM(
40+
logs_detection_output = AgenticLLMV2(
4141
dict(
4242
max_agent_calls=5,
4343
agent_system_prompt="""\
@@ -67,7 +67,7 @@ def run(self) -> dict:
6767
)
6868
).run()
6969

70-
analysis_output = AgenticLLM(
70+
analysis_output = AgenticLLMV2(
7171
dict(
7272
max_agent_calls=5,
7373
prompt_value=logs_detection_output,
Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from pathlib import Path
22

3-
from patchwork.common.multiturn_strategy.agentic_strategy import (
4-
AgentConfig,
5-
AgenticStrategy,
6-
)
3+
from patchwork.common.client.llm.aio import AioLlmClient
4+
from patchwork.common.multiturn_strategy.agentic_strategy import AgenticStrategy
75
from patchwork.common.tools import Tool
86
from patchwork.step import Step
97
from patchwork.steps.AgenticLLM.typed import AgenticLLMInputs, AgenticLLMOutputs
@@ -15,21 +13,18 @@ def __init__(self, inputs):
1513
base_path = inputs.get("base_path")
1614
if base_path is None:
1715
base_path = str(Path.cwd())
18-
self.conversation_limit = int(inputs.get("max_agent_calls", 1))
16+
self.conversation_limit = int(int(inputs.get("max_llm_calls", 2)) / 2)
1917
self.agentic_strategy = AgenticStrategy(
20-
api_key=inputs.get("anthropic_api_key"),
21-
template_data=inputs.get("prompt_value", {}),
22-
system_prompt_template=inputs.get("system_prompt", "Summarise from our previous conversation"),
18+
llm_client=AioLlmClient.create_aio_client(inputs),
19+
tool_set=Tool.get_tools(path=base_path),
20+
template_data=inputs.get("prompt_value"),
21+
system_prompt_template=inputs.get("system_prompt"),
2322
user_prompt_template=inputs.get("user_prompt"),
24-
agent_configs=[
25-
AgentConfig(
26-
name="Assistant",
27-
tool_set=Tool.get_tools(path=base_path),
28-
system_prompt=inputs.get("agent_system_prompt"),
29-
)
30-
],
31-
example_json=inputs.get("example_json"),
3223
)
3324

3425
def run(self) -> dict:
35-
return self.agentic_strategy.execute(limit=self.conversation_limit)
26+
self.agentic_strategy.execute(limit=self.conversation_limit)
27+
return dict(
28+
conversation_history=self.agentic_strategy.history,
29+
tool_records=self.agentic_strategy.tool_records,
30+
)
Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing_extensions import Annotated, Any, Dict, TypedDict
1+
from typing_extensions import Annotated, Any, Dict, List, TypedDict
22

33
from patchwork.common.utils.step_typing import StepTypeConfig
4+
from patchwork.steps.CallLLM.CallLLM import TOKEN_URL
45

56

67
class AgenticLLMInputs(TypedDict, total=False):
@@ -9,10 +10,33 @@ class AgenticLLMInputs(TypedDict, total=False):
910
system_prompt: str
1011
user_prompt: str
1112
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
12-
anthropic_api_key: str
13-
agent_system_prompt: str
14-
example_json: str
13+
openai_api_key: Annotated[
14+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
15+
]
16+
anthropic_api_key: Annotated[
17+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
18+
]
19+
patched_api_key: Annotated[
20+
str,
21+
StepTypeConfig(
22+
is_config=True,
23+
or_op=["openai_api_key", "google_api_key", "anthropic_api_key"],
24+
msg=f"""\
25+
Model API key not found.
26+
Please login at: "{TOKEN_URL}"
27+
Please go to the Integration's tab and generate an API key.
28+
Please copy the access token that is generated, and add `--patched_api_key=<token>` to the command line.
29+
30+
If you are using a OpenAI API Key, please set `--openai_api_key=<token>`.""",
31+
),
32+
]
33+
google_api_key: Annotated[
34+
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
35+
]
1536

1637

1738
class AgenticLLMOutputs(TypedDict):
18-
pass
39+
conversation_history: List[Dict]
40+
tool_records: List[Dict]
41+
# request_tokens: int
42+
# response_tokens: int

0 commit comments

Comments
 (0)