Skip to content

Commit 9569d0a

Browse files
committed
Updated agent
1 parent ee820b7 commit 9569d0a

File tree

8 files changed

+87
-39
lines changed

8 files changed

+87
-39
lines changed

text_2_sql/autogen/agentic_text_2_sql.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"source": [
99
"import dotenv\n",
1010
"import logging\n",
11-
"from agentic_text_2_sql import text_2_sql_generator, text_2_sql_cache_updater"
11+
"from autogen_agentchat.task import Console\n",
12+
"from agentic_text_2_sql import text_2_sql_generator"
1213
]
1314
},
1415
{
@@ -35,7 +36,7 @@
3536
"metadata": {},
3637
"outputs": [],
3738
"source": [
38-
"result = await text_2_sql_generator.run(task=\"What are the total number of sales within 2008?\")"
39+
"result = text_2_sql_generator.run_stream(task=\"What are the total number of sales within 2008?\")"
3940
]
4041
},
4142
{
@@ -44,7 +45,7 @@
4445
"metadata": {},
4546
"outputs": [],
4647
"source": [
47-
"print(result.messages[-1].content)"
48+
"await Console(result)"
4849
]
4950
},
5051
{

text_2_sql/autogen/agentic_text_2_sql.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from autogen_agentchat.teams import SelectorGroupChat
33
from utils.models import MINI_MODEL
44
from utils.llm_agent_creator import LLMAgentCreator
5-
from autogen_core.components.models import FunctionExecutionResult
65
import logging
6+
from custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
7+
import json
78

89
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
910
"sql_query_generation_agent",
@@ -16,39 +17,47 @@
1617
target_engine="Microsoft SQL Server",
1718
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
1819
)
19-
SQL_QUERY_CACHE_AGENT = LLMAgentCreator.create("sql_query_cache_agent")
20+
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
2021
ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
2122
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create("question_decomposition_agent")
2223

2324

2425
def text_2_sql_generator_selector_func(messages):
2526
logging.info("Messages: %s", messages)
27+
decision = None # Initialize decision variable
28+
2629
if len(messages) == 1:
27-
return "sql_query_cache_agent"
30+
decision = "sql_query_cache_agent"
2831

2932
elif (
3033
messages[-1].source == "sql_query_cache_agent"
31-
and isinstance(messages[-1].content, FunctionExecutionResult)
32-
and messages[-1].content.content is not None
34+
and messages[-1].content is not None
3335
):
34-
return "sql_query_correction_agent"
36+
cache_result = json.loads(messages[-1].content)
37+
if cache_result.get("cached_questions_and_schemas") is not None:
38+
decision = "sql_query_correction_agent"
39+
else:
40+
decision = "sql_schema_selection_agent"
3541

3642
elif messages[-1].source == "question_decomposition_agent":
37-
return "sql_schema_selection_agent"
43+
decision = "sql_schema_selection_agent"
3844

3945
elif messages[-1].source == "sql_schema_selection_agent":
40-
return "sql_query_generation_agent"
46+
decision = "sql_query_generation_agent"
4147

4248
elif (
4349
messages[-1].source == "sql_query_correction_agent"
4450
and messages[-1].content == "VALIDATED"
4551
):
46-
return "answer_agent"
52+
decision = "answer_agent"
4753

4854
elif messages[-1].source == "sql_query_correction_agent":
49-
return "sql_query_correction_agent"
55+
decision = "sql_query_correction_agent"
56+
57+
# Log the decision
58+
logging.info("Decision: %s", decision)
5059

51-
return None
60+
return decision
5261

5362

5463
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10)
@@ -61,6 +70,7 @@ def text_2_sql_generator_selector_func(messages):
6170
ANSWER_AGENT,
6271
QUESTION_DECOMPOSITION_AGENT,
6372
],
73+
allow_repeated_speaker=False,
6474
model_client=MINI_MODEL,
6575
termination_condition=termination,
6676
selector_func=text_2_sql_generator_selector_func,

text_2_sql/autogen/custom_agents/__init__.py

Whitespace-only changes.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import AsyncGenerator, List, Sequence
2+
3+
from autogen_agentchat.agents import BaseChatAgent
4+
from autogen_agentchat.base import Response
5+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
6+
from autogen_core.base import CancellationToken
7+
from utils.sql_utils import fetch_queries_from_cache
8+
import json
9+
import logging
10+
11+
12+
class SqlQueryCacheAgent(BaseChatAgent):
13+
def __init__(self):
14+
super().__init__(
15+
"sql_query_cache_agent",
16+
"An agent that fetches the queries from the cache based on the user question.",
17+
)
18+
19+
@property
20+
def produced_message_types(self) -> List[type[ChatMessage]]:
21+
return [TextMessage]
22+
23+
async def on_messages(
24+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
25+
) -> Response:
26+
# Calls the on_messages_stream.
27+
response: Response | None = None
28+
async for message in self.on_messages_stream(messages, cancellation_token):
29+
if isinstance(message, Response):
30+
response = message
31+
assert response is not None
32+
return response
33+
34+
async def on_messages_stream(
35+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
36+
) -> AsyncGenerator[AgentMessage | Response, None]:
37+
user_question = messages[0].content
38+
39+
# Fetch the queries from the cache based on the user question.
40+
logging.info("Fetching queries from cache based on the user question...")
41+
42+
cached_queries = await fetch_queries_from_cache(user_question)
43+
44+
yield Response(
45+
chat_message=TextMessage(
46+
content=json.dumps(cached_queries), source=self.name
47+
)
48+
)
49+
50+
async def on_reset(self, cancellation_token: CancellationToken) -> None:
51+
pass

text_2_sql/autogen/llm_agents/sql_query_cache_agent.yaml

Lines changed: 0 additions & 12 deletions
This file was deleted.

text_2_sql/autogen/llm_agents/sql_schema_selection_agent.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ description:
77
system_message:
88
"You are a helpful AI Assistant that specialises in selecting relevant SQL schemas to answer a given user's question.
99
10-
Use the tools available to you to select the correct schemas that will help. Extract key terms from the user's question and use them to search for the correct schema."
10+
Use the tools available to you to select the correct schemas that will help. Extract key terms from the user's question and use them to search for the correct schema.
11+
12+
Limit the number of calls to the 'sql_get_entity_schemas_tool' tool to avoid unnecessary calls.
13+
14+
If you are unsure about the schema, you can ask the user for more information or ask for clarification."
1115
tools:
1216
- sql_get_entity_schemas_tool

text_2_sql/autogen/utils/llm_agent_creator.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import yaml
22
from autogen_core.components.tools import FunctionTool
3-
from autogen_agentchat.agents import ToolUseAssistantAgent
3+
from autogen_agentchat.agents import AssistantAgent
44
from utils.sql_utils import (
55
query_execution,
66
get_entity_schemas,
7-
fetch_queries_from_cache,
87
)
98
from utils.models import MINI_MODEL
109
from jinja2 import Template
@@ -37,11 +36,6 @@ def get_tool(cls, tool_name):
3736
get_entity_schemas,
3837
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
3938
)
40-
elif tool_name == "sql_query_cache_tool":
41-
return FunctionTool(
42-
fetch_queries_from_cache,
43-
description="Fetch the pre-assembled queries, and potential results from the cache based on the user's question.",
44-
)
4539
else:
4640
raise ValueError(f"Tool {tool_name} not found")
4741

@@ -62,9 +56,9 @@ def create(cls, name: str, **kwargs):
6256
for tool in agent_file["tools"]:
6357
tools.append(cls.get_tool(tool))
6458

65-
agent = ToolUseAssistantAgent(
59+
agent = AssistantAgent(
6660
name=name,
67-
registered_tools=tools,
61+
tools=tools,
6862
model_client=cls.get_model(agent_file["model"]),
6963
description=cls.get_property_and_render_parameters(
7064
agent_file, "description", kwargs

text_2_sql/autogen/utils/sql_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def fetch_queries_from_cache(question: str) -> str:
112112
)
113113

114114
if len(cached_schemas) == 0:
115-
return None
115+
return {"cached_questions_and_schemas": None}
116116

117117
logging.info("Cached schemas: %s", cached_schemas)
118118
if PRE_RUN_QUERY_CACHE and len(cached_schemas) > 0:
@@ -139,6 +139,6 @@ async def fetch_queries_from_cache(question: str) -> str:
139139
"schemas": sql_query["Schemas"],
140140
}
141141

142-
return query_result_store
142+
return {"cached_questions_and_schemas": query_result_store}
143143

144-
return {"cached_questions": cached_schemas}
144+
return {"cached_questions_and_schemas": cached_schemas}

0 commit comments

Comments
 (0)