Skip to content

Commit ee1c514

Browse files
committed
add forgotten v2
1 parent fabd099 commit ee1c514

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import sys
5+
6+
from pydantic import BaseModel
7+
from pydantic_ai import Agent
8+
from pydantic_ai.models.anthropic import AnthropicModel
9+
from typing_extensions import Any, Dict, Optional, Union
10+
11+
from patchwork.common.client.llm.utils import example_json_to_base_model
12+
from patchwork.common.tools import Tool
13+
from patchwork.common.utils.utils import mustache_render
14+
15+
_COMPLETION_FLAG_ATTRIBUTE = "is_task_completed"
16+
_MESSAGE_ATTRIBUTE = "message"
17+
18+
19+
class AgentConfig(BaseModel):
20+
class Config:
21+
arbitrary_types_allowed = True
22+
23+
name: str
24+
tool_set: Dict[str, Tool]
25+
system_prompt: str = ""
26+
example_json: Union[
27+
str, Dict[str, Any]
28+
] = f'{{"{_MESSAGE_ATTRIBUTE}":"message", "{_COMPLETION_FLAG_ATTRIBUTE}": false}}'
29+
30+
31+
class AgenticStrategyV2:
32+
def __init__(
33+
self,
34+
api_key: str,
35+
template_data: dict[str, str],
36+
system_prompt_template: str,
37+
user_prompt_template: str,
38+
agent_configs: list[AgentConfig],
39+
example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
40+
limit: Optional[int] = None,
41+
):
42+
self.__limit = limit
43+
self.__template_data = template_data
44+
self.__user_prompt_template = user_prompt_template
45+
model = AnthropicModel("claude-3-5-sonnet-latest", api_key=api_key)
46+
self.__summariser = Agent(
47+
model,
48+
system_prompt=mustache_render(system_prompt_template, self.__template_data),
49+
result_type=example_json_to_base_model(example_json),
50+
model_settings=dict(parallel_tool_calls=False),
51+
)
52+
self.__agents = []
53+
for agent_config in agent_configs:
54+
tools = []
55+
for tool in agent_config.tool_set.values():
56+
tools.append(tool.to_pydantic_ai_function_tool())
57+
agent = Agent(
58+
model,
59+
name=agent_config.name,
60+
system_prompt=mustache_render(agent_config.system_prompt, self.__template_data),
61+
tools=tools,
62+
result_type=example_json_to_base_model(agent_config.example_json),
63+
model_settings=dict(parallel_tool_calls=False),
64+
)
65+
66+
self.__agents.append(agent)
67+
68+
def execute(self, limit: Optional[int] = None) -> dict:
69+
agents_result = dict()
70+
try:
71+
for index, agent in enumerate(self.__agents):
72+
user_message = mustache_render(self.__user_prompt_template, self.__template_data)
73+
message_history = None
74+
agent_output = None
75+
for i in range(limit or self.__limit or sys.maxsize):
76+
agent_output = agent.run_sync(user_message, message_history=message_history)
77+
message_history = agent_output.all_messages()
78+
if getattr(agent_output.data, _COMPLETION_FLAG_ATTRIBUTE, False):
79+
break
80+
user_message = "Please continue"
81+
agents_result[index] = agent_output
82+
except Exception as e:
83+
logging.error(e)
84+
85+
if len(agents_result) == 0:
86+
return dict()
87+
88+
if len(agents_result) == 1:
89+
final_result = self.__summariser.run_sync(
90+
"From the actions taken by the assistant. Please give me the result.",
91+
message_history=next(v for _, v in agents_result.items()).all_messages(),
92+
)
93+
else:
94+
agent_summaries = []
95+
for agent_result in agents_result.values():
96+
agent_summary_result = self.__summariser.run_sync(
97+
"From the actions taken by the assistant. Please give me the result.",
98+
message_history=agent_result.all_messages(),
99+
)
100+
agent_summary = getattr(agent_summary_result.data, _MESSAGE_ATTRIBUTE, None)
101+
if agent_summary is None:
102+
continue
103+
104+
agent_summaries.append(agent_summary)
105+
agent_summary_list = "\n* " + "\n* ".join(agent_summaries)
106+
final_result = self.__summariser.run_sync(
107+
"Please give me the result from the following summary of what the assistants have done."
108+
+ agent_summary_list,
109+
)
110+
111+
return final_result.data.dict()

0 commit comments

Comments
 (0)