Commit ce99726

Add database agent
1 parent b230bdd commit ce99726

File tree: 18 files changed, +472 -37 lines changed
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
import asyncio
from functools import partial
from typing import Any, Optional, Union

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.agent import AgentRunResult

from patchwork.common.client.llm.protocol import LlmClient
from patchwork.common.client.llm.utils import example_json_to_base_model
from patchwork.common.tools import Tool


class StepCompletedResult(BaseModel):
    is_step_completed: bool


class PlanCompletedResult(BaseModel):
    is_plan_completed: bool


class ExecutionResult(BaseModel):
    json_data: str
    message: str
    is_completed: bool


class _Plan:
    def __init__(self, initial_plan: Optional[list[str]] = None):
        self.__plan = initial_plan or []
        self.__cursor = 0

    def advance(self) -> bool:
        self.__cursor += 1
        return self.__cursor < len(self.__plan)

    def is_empty(self) -> bool:
        return len(self.__plan) == 0

    def register_steps(self, agent: Agent):
        agent.tool_plain(self.get_current_plan)
        agent.tool_plain(self.get_current_step)
        agent.tool_plain(self.get_current_step_index)
        agent.tool_plain(self.add_step)
        agent.tool_plain(self.delete_step)

    def get_current_plan(self) -> str:
        return "\n".join([f"{i}. {step}" for i, step in enumerate(self.__plan)])

    def get_current_step(self) -> str:
        if len(self.__plan) == 0:
            return "There is currently no plan"

        return self.__plan[self.__cursor]

    def get_current_step_index(self) -> int:
        return self.__cursor

    def add_step(self, index: int, step: str) -> str:
        if index < 0:
            return "index cannot be a negative number"

        if index >= len(self.__plan):
            insertion_func = self.__plan.append
        else:
            insertion_func = partial(self.__plan.insert, index)

        insertion_func(step)
        return "Added step\nCurrent plan:\n" + self.get_current_plan()

    def delete_step(self, step: str) -> str:
        try:
            i = self.__plan.index(step)
            self.__plan.pop(i)
            return self.get_current_plan()
        except ValueError:
            return "Step not found in plan\nCurrent plan:\n" + self.get_current_plan()


class PlanningStrategy:
    def __init__(
        self,
        llm_client: LlmClient,
        planner_system_prompt: str,
        executor_system_prompt: str,
        executor_tool_set: dict[str, Tool],
        example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
    ):
        self.planner = Agent(
            llm_client,
            name="Planner",
            system_prompt=planner_system_prompt,
            model_settings=dict(
                parallel_tool_calls=False,
                model="gemini-2.0-flash",
            ),
        )

        self.plan = _Plan()
        self.plan.register_steps(self.planner)

        self.executor = Agent(
            llm_client,
            name="Executor",
            system_prompt=executor_system_prompt,
            result_type=ExecutionResult,
            tools=[tool.to_pydantic_ai_function_tool() for tool in executor_tool_set.values()],
            model_settings=dict(
                parallel_tool_calls=False,
                model="gemini-2.0-flash",
            ),
        )

        self.__summariser = Agent(
            llm_client,
            result_retries=5,
            system_prompt="""\
Please summarise the conversation given and provide the result in the structure that is asked of you.
""",
            result_type=example_json_to_base_model(example_json),
            model_settings=dict(
                parallel_tool_calls=False,
                model="gemini-2.0-flash",
            ),
        )

        self.reset()

    def reset(self):
        self.__request_tokens = 0
        self.__response_tokens = 0

    def usage(self):
        return {
            "request_tokens": self.__request_tokens,
            "response_tokens": self.__response_tokens,
        }

    def __agent_run(self, agent: Agent, prompt: str, **kwargs) -> AgentRunResult[Any]:
        planner_response = agent.run_sync(prompt, **kwargs)
        self.__request_tokens += planner_response.usage().request_tokens
        self.__response_tokens += planner_response.usage().response_tokens
        return planner_response

    def run(self, task: str, conversation_limit: int = 10) -> dict:
        loop = asyncio.new_event_loop()

        planner_response = self.__agent_run(self.planner, f"Produce the initial plan for {task}")
        planner_history = planner_response.all_messages()
        if self.plan.is_empty():
            planner_response = self.__agent_run(
                self.planner, "Please use the tools provided to set up the plan", message_history=planner_history
            )
            planner_history = planner_response.all_messages()

        for i in range(conversation_limit):
            step = self.plan.get_current_step()
            executor_prompt = f"Please execute the following task: {step}"
            response = self.__agent_run(self.executor, executor_prompt)

            plan_str = self.plan.get_current_plan()
            step_index = self.plan.get_current_step_index()
            planner_prompt = f"""\
The current plan is:
{plan_str}

We are currently at step {step_index}.
If the current step is not completed, edit the current step.

The execution result for step {step_index} is:
{response.data}

"""
            planner_response = self.__agent_run(
                self.planner,
                planner_prompt,
                message_history=planner_history,
                result_type=StepCompletedResult,
            )
            planner_history = planner_response.all_messages()
            if not planner_response.data.is_step_completed:
                continue

            if self.plan.advance():
                continue

            planner_response = self.__agent_run(
                self.planner,
                "Is the task completed? If the task is not completed please add more steps using the tools provided.",
                message_history=planner_history,
                result_type=PlanCompletedResult,
            )
            if planner_response.data.is_plan_completed:
                break

        final_result = self.__agent_run(
            self.__summariser,
            "From the actions taken by the assistant, please give me the result.",
            message_history=planner_history,
        )

        loop.close()
        return final_result.data.dict()
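
For reference, a small sketch (not part of the commit) of how the _Plan cursor behaves, using only the methods defined above; the step strings are made up:

plan = _Plan()
plan.add_step(0, "Inspect the schema")   # index >= len(plan) appends
plan.add_step(99, "Run the query")       # 99 >= 1, so this also appends
plan.add_step(1, "Validate the SQL")     # 1 < 2, so this inserts before "Run the query"
print(plan.get_current_plan())
# 0. Inspect the schema
# 1. Validate the SQL
# 2. Run the query
print(plan.get_current_step())           # "Inspect the schema" (cursor starts at 0)
print(plan.advance())                    # True: steps remain after moving the cursor forward

The planner agent mutates the plan through the tool_plain-registered methods, while run() walks the cursor forward only once the planner confirms a step is completed.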
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
from typing_extensions import Any, Union

from patchwork.common.tools import Tool
from patchwork.steps import CallSQL


class DatabaseQueryTool(Tool, tool_name="db_query_tool"):
    def __init__(self, inputs: dict[str, Any]):
        super().__init__()
        self.db_settings = inputs.copy()

    @property
    def json_schema(self) -> dict:
        return {
            "name": "db_query_tool",
            "description": """\
Run SQL Query on current database.
""",
            "input_schema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Database query to run.",
                    }
                },
                "required": ["query"],
            },
        }

    def execute(self, query: str) -> Union[list[dict[str, Any]], str]:
        db_settings = self.db_settings.copy()
        db_settings["db_query"] = query
        try:
            return CallSQL(db_settings).run().get("results", [])
        except Exception as e:
            return str(e)
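
A minimal, hypothetical sketch of how the two new files could be wired together. It is not part of the commit: `llm_client` is assumed to be an existing LlmClient implementation obtained elsewhere in patchwork, the prompts and the database settings keys are placeholders (check CallSQL's expected inputs), and imports are omitted because the new files' paths are not shown in this view.

# Hypothetical database settings; the actual keys must match what CallSQL expects.
db_tool = DatabaseQueryTool(inputs={"db_dialect": "sqlite", "db_path": "./example.db"})

strategy = PlanningStrategy(
    llm_client,                                                                # assumed LlmClient instance
    planner_system_prompt="Plan the database task as a short list of steps.",  # placeholder prompt
    executor_system_prompt="Execute one step at a time using the tools given.",  # placeholder prompt
    executor_tool_set={"db_query_tool": db_tool},
    example_json='{"summary": "what was found"}',                              # shape of the final summary
)

result = strategy.run("Find the five most recent orders", conversation_limit=10)
print(result)            # dict shaped like example_json
print(strategy.usage())  # {"request_tokens": ..., "response_tokens": ...}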

patchwork/steps/AgenticLLM/typed.py

Lines changed: 14 additions & 4 deletions
@@ -11,10 +11,14 @@ class AgenticLLMInputs(TypedDict, total=False):
     user_prompt: str
     max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
     openai_api_key: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
+        str,
+        StepTypeConfig(
+            is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"]
+        ),
     ]
     anthropic_api_key: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
+        str,
+        StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"]),
     ]
     patched_api_key: Annotated[
         str,
@@ -31,10 +35,16 @@ class AgenticLLMInputs(TypedDict, total=False):
         ),
     ]
     google_api_key: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
+        str,
+        StepTypeConfig(
+            is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"]
+        ),
     ]
     client_is_gcp: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
+        str,
+        StepTypeConfig(
+            is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"]
+        ),
     ]

patchwork/steps/CallLLM/typed.py

Lines changed: 14 additions & 4 deletions
@@ -13,10 +13,14 @@ class CallLLMInputs(TypedDict, total=False):
     model_args: Annotated[str, StepTypeConfig(is_config=True)]
     client_args: Annotated[str, StepTypeConfig(is_config=True)]
     openai_api_key: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
+        str,
+        StepTypeConfig(
+            is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"]
+        ),
     ]
     anthropic_api_key: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
+        str,
+        StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"]),
     ]
     patched_api_key: Annotated[
         str,
@@ -33,10 +37,16 @@ class CallLLMInputs(TypedDict, total=False):
         ),
     ]
     google_api_key: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
+        str,
+        StepTypeConfig(
+            is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"]
+        ),
     ]
     client_is_gcp: Annotated[
-        str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
+        str,
+        StepTypeConfig(
+            is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"]
+        ),
     ]
     file: Annotated[str, StepTypeConfig(is_path=True)]

0 commit comments
