|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +import sys |
| 5 | + |
| 6 | +from pydantic import BaseModel |
| 7 | +from pydantic_ai import Agent |
| 8 | +from pydantic_ai.models.anthropic import AnthropicModel |
| 9 | +from typing_extensions import Any, Dict, Optional, Union |
| 10 | + |
| 11 | +from patchwork.common.client.llm.utils import example_json_to_base_model |
| 12 | +from patchwork.common.tools import Tool |
| 13 | +from patchwork.common.utils.utils import mustache_render |
| 14 | + |
| 15 | +_COMPLETION_FLAG_ATTRIBUTE = "is_task_completed" |
| 16 | +_MESSAGE_ATTRIBUTE = "message" |
| 17 | + |
| 18 | + |
| 19 | +class AgentConfig(BaseModel): |
| 20 | + class Config: |
| 21 | + arbitrary_types_allowed = True |
| 22 | + |
| 23 | + name: str |
| 24 | + tool_set: Dict[str, Tool] |
| 25 | + system_prompt: str = "" |
| 26 | + example_json: Union[ |
| 27 | + str, Dict[str, Any] |
| 28 | + ] = f'{{"{_MESSAGE_ATTRIBUTE}":"message", "{_COMPLETION_FLAG_ATTRIBUTE}": false}}' |
| 29 | + |
| 30 | + |
| 31 | +class AgenticStrategyV2: |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + api_key: str, |
| 35 | + template_data: dict[str, str], |
| 36 | + system_prompt_template: str, |
| 37 | + user_prompt_template: str, |
| 38 | + agent_configs: list[AgentConfig], |
| 39 | + example_json: Union[str, dict[str, Any]] = '{"output":"output text"}', |
| 40 | + limit: Optional[int] = None, |
| 41 | + ): |
| 42 | + self.__limit = limit |
| 43 | + self.__template_data = template_data |
| 44 | + self.__user_prompt_template = user_prompt_template |
| 45 | + model = AnthropicModel("claude-3-5-sonnet-latest", api_key=api_key) |
| 46 | + self.__summariser = Agent( |
| 47 | + model, |
| 48 | + system_prompt=mustache_render(system_prompt_template, self.__template_data), |
| 49 | + result_type=example_json_to_base_model(example_json), |
| 50 | + model_settings=dict(parallel_tool_calls=False), |
| 51 | + ) |
| 52 | + self.__agents = [] |
| 53 | + for agent_config in agent_configs: |
| 54 | + tools = [] |
| 55 | + for tool in agent_config.tool_set.values(): |
| 56 | + tools.append(tool.to_pydantic_ai_function_tool()) |
| 57 | + agent = Agent( |
| 58 | + model, |
| 59 | + name=agent_config.name, |
| 60 | + system_prompt=mustache_render(agent_config.system_prompt, self.__template_data), |
| 61 | + tools=tools, |
| 62 | + result_type=example_json_to_base_model(agent_config.example_json), |
| 63 | + model_settings=dict(parallel_tool_calls=False), |
| 64 | + ) |
| 65 | + |
| 66 | + self.__agents.append(agent) |
| 67 | + |
| 68 | + def execute(self, limit: Optional[int] = None) -> dict: |
| 69 | + agents_result = dict() |
| 70 | + try: |
| 71 | + for index, agent in enumerate(self.__agents): |
| 72 | + user_message = mustache_render(self.__user_prompt_template, self.__template_data) |
| 73 | + message_history = None |
| 74 | + agent_output = None |
| 75 | + for i in range(limit or self.__limit or sys.maxsize): |
| 76 | + agent_output = agent.run_sync(user_message, message_history=message_history) |
| 77 | + message_history = agent_output.all_messages() |
| 78 | + if getattr(agent_output.data, _COMPLETION_FLAG_ATTRIBUTE, False): |
| 79 | + break |
| 80 | + user_message = "Please continue" |
| 81 | + agents_result[index] = agent_output |
| 82 | + except Exception as e: |
| 83 | + logging.error(e) |
| 84 | + |
| 85 | + if len(agents_result) == 0: |
| 86 | + return dict() |
| 87 | + |
| 88 | + if len(agents_result) == 1: |
| 89 | + final_result = self.__summariser.run_sync( |
| 90 | + "From the actions taken by the assistant. Please give me the result.", |
| 91 | + message_history=next(v for _, v in agents_result.items()).all_messages(), |
| 92 | + ) |
| 93 | + else: |
| 94 | + agent_summaries = [] |
| 95 | + for agent_result in agents_result.values(): |
| 96 | + agent_summary_result = self.__summariser.run_sync( |
| 97 | + "From the actions taken by the assistant. Please give me the result.", |
| 98 | + message_history=agent_result.all_messages(), |
| 99 | + ) |
| 100 | + agent_summary = getattr(agent_summary_result.data, _MESSAGE_ATTRIBUTE, None) |
| 101 | + if agent_summary is None: |
| 102 | + continue |
| 103 | + |
| 104 | + agent_summaries.append(agent_summary) |
| 105 | + agent_summary_list = "\n* " + "\n* ".join(agent_summaries) |
| 106 | + final_result = self.__summariser.run_sync( |
| 107 | + "Please give me the result from the following summary of what the assistants have done." |
| 108 | + + agent_summary_list, |
| 109 | + ) |
| 110 | + |
| 111 | + return final_result.data.dict() |
0 commit comments