
Commit 3777e3e

Work work on agentic
1 parent 8b1d471 commit 3777e3e

8 files changed (+231, -29 lines changed)


text_2_sql/autogen/agents.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
from autogen_agentchat.agents import ToolUseAssistantAgent
from utils.models import MINI_MODEL

ANSWER_AGENT = ToolUseAssistantAgent(
    name="Answer_Revision_Agent",
    registered_tools=[],
    model_client=MINI_MODEL,
    description="An agent that takes the user's question and the outputs from the SQL queries to provide an answer to the user's question.",
    system_message="You are a helpful AI assistant. Take the user's question and the outputs from the SQL queries to provide an answer to the user's question.",
)

QUERY_DECOMPOSITION_AGENT = ToolUseAssistantAgent(
    name="Query_Decomposition_Agent",
    registered_tools=[],
    model_client=MINI_MODEL,
    description="An agent that decomposes the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query.",
    system_message="You are a helpful AI assistant. Decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query.",
)
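
For context, the decomposition agent is expected to turn one complex question into several simpler sub-questions that each map onto a single SQL query. The commit does not pin down an output schema for this, so the example below is purely illustrative; the keys and values are hypothetical, not part of the agent's contract.

# Hypothetical decomposition of a complex question into SQL-sized parts.
# The structure is illustrative only; the agent's output is free-form text
# unless a schema is enforced elsewhere.
example_decomposition = {
    "question": "Which product category grew fastest last quarter, and who were its top customers?",
    "sub_questions": [
        "What was the quarter-over-quarter revenue growth per product category?",
        "Who were the top customers by revenue for the fastest-growing category?",
    ],
}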
Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
+autogen-core
 autogen-agentchat
-autogen-ext[openai]
+autogen-ext[openai,azure]
 aioodbc
 azure-search
 azure-search-documents==11.6.0b5
 azure-identity
 python-dotenv
+openai

text_2_sql/autogen/sql_tools.py

Lines changed: 0 additions & 28 deletions
This file was deleted.
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
from autogen_core.components.tools import FunctionTool
from autogen_agentchat.agents import ToolUseAssistantAgent
from utils.sql_utils import (
    query_execution,
    get_entity_schemas,
    fetch_queries_from_cache,
)
from utils.models import MINI_MODEL

SQL_QUERY_EXECUTION_TOOL = FunctionTool(
    query_execution,
    description="Runs an SQL query against the SQL Database to extract information.",
)

SQL_GET_ENTITY_SCHEMAS_TOOL = FunctionTool(
    get_entity_schemas,
    description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
)

SQL_QUERY_CACHE_TOOLS = FunctionTool(
    fetch_queries_from_cache,
    description="Fetches the pre-assembled queries, and potentially their results, from the cache based on the user's question.",
)

SQL_QUERY_AGENT = ToolUseAssistantAgent(
    name="SQL_Query_Agent",
    registered_tools=[SQL_QUERY_EXECUTION_TOOL],
    model_client=MINI_MODEL,
    description="An agent that can take a user's question and run an SQL query against the SQL Database to extract information.",
    system_message="You are a helpful AI assistant. Solve tasks using your tools. Specifically, you can take into consideration the user's request and run an SQL query against the SQL Database to extract information.",
)

SQL_SCHEMA_EXTRACTION_AGENT = ToolUseAssistantAgent(
    name="SQL_Schema_Extraction_Agent",
    registered_tools=[SQL_GET_ENTITY_SCHEMAS_TOOL],
    model_client=MINI_MODEL,
    description="An agent that can take a user's question and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term.",
    system_message="You are a helpful AI assistant. Solve tasks using your tools. Specifically, you can take into consideration the user's request and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term.",
)

SQL_QUERY_CORRECTION_AGENT = ToolUseAssistantAgent(
    name="SQL_Query_Correction_Agent",
    registered_tools=[SQL_QUERY_EXECUTION_TOOL],
    model_client=MINI_MODEL,
    description="An agent that looks at the SQL query and its results, and corrects any mistakes in the SQL query.",
    system_message="",
)

SQL_QUERY_CACHE_AGENT = ToolUseAssistantAgent(
    name="SQL_Query_Cache_Agent",
    registered_tools=[SQL_QUERY_CACHE_TOOLS],
    model_client=MINI_MODEL,
    description="An agent that fetches queries from the cache based on the user's question.",
    system_message="",
)
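
To make the intended flow of these agents concrete, the underlying tool functions can also be chained directly, bypassing the agent framework: check the query cache first, fall back to a schema lookup, then execute a query. A minimal sketch, assuming the environment variables read in utils/sql_utils.py are set; the SQL statement is a placeholder for what SQL_Query_Agent would normally generate.

import asyncio

from utils.sql_utils import (
    fetch_queries_from_cache,
    get_entity_schemas,
    query_execution,
)


async def answer_question(question: str) -> dict:
    # 1. Try the query cache first; a hit may already include pre-run results.
    cached = await fetch_queries_from_cache(question)
    if cached is not None:
        return {"source": "cache", "data": cached}

    # 2. On a cache miss, pull candidate entity schemas for the question.
    schemas = await get_entity_schemas(question)

    # 3. Placeholder query; in the agentic flow it would be written by
    #    SQL_Query_Agent from the schemas above.
    placeholder_query = "SELECT 1 AS smoke_test"
    results = await query_execution(placeholder_query)
    return {"source": "database", "schemas": schemas, "data": results}


if __name__ == "__main__":
    print(asyncio.run(answer_question("Which customers spent the most last month?")))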
File renamed without changes.

text_2_sql/autogen/utils/models.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
from autogen_ext.models import AzureOpenAIChatCompletionClient
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import os

# Create the token provider
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

MINI_MODEL = AzureOpenAIChatCompletionClient(
    model="{your-azure-deployment}",
    api_version="2024-06-01",
    azure_endpoint=os.environ["OpenAI__Endpoint"],
    # Optional if you choose key-based authentication.
    azure_ad_token_provider=token_provider,
    # api_key="sk-...", # For key-based authentication.
    model_capabilities={
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
)
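
Since the client reads OpenAI__Endpoint from the environment at import time (and python-dotenv is added to the requirements above), one simple way to supply it during local development is a .env file loaded before utils.models is imported. A minimal sketch; the variable name mirrors this commit, but loading .env at this exact point is an assumption rather than something the commit prescribes.

# Hypothetical local entry point: load environment variables from a .env file
# before importing modules that read os.environ at import time.
from dotenv import load_dotenv

load_dotenv()  # expects e.g. OpenAI__Endpoint=https://<your-resource>.openai.azure.com/

from utils.models import MINI_MODEL  # noqa: E402  (imported after load_dotenv on purpose)

print(MINI_MODEL)  # the shared Azure OpenAI client used by all agents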
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
import logging
import os
import aioodbc
from typing import Annotated
from utils.ai_search_utils import run_ai_search_query
import json
import asyncio

USE_QUERY_CACHE = os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"

PRE_RUN_QUERY_CACHE = (
    os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
)


async def get_entity_schemas(
    text: Annotated[
        str,
        "The text to run a semantic search against. Relevant entities will be returned.",
    ],
) -> str:
    """Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

    Args:
    ----
        text (str): The text to run the search against.

    Returns:
    -------
        str: The schema of the views or tables in JSON format.
    """
    schemas = await run_ai_search_query(
        text,
        ["DescriptionEmbedding"],
        ["Entity", "EntityName", "Description", "Columns"],
        os.environ["AIService__AzureSearchOptions__Text2Sql__Index"],
        os.environ["AIService__AzureSearchOptions__Text2Sql__SemanticConfig"],
        top=3,
    )

    for schema in schemas:
        entity = schema["Entity"]
        database = os.environ["Text2Sql__DatabaseName"]
        schema["SelectFromEntity"] = f"{database}.{entity}"

    return json.dumps(schemas, default=str)

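For reference, the JSON string returned above contains the fields selected from the AI Search index plus the derived SelectFromEntity value. The entry below is illustrative only: the values are invented and the internal structure of Columns is an assumption, since it is defined by the deployed index rather than by this commit.

# Illustrative shape of one entry in the JSON returned by get_entity_schemas().
example_schema_entry = {
    "Entity": "vSalesByRegion",
    "EntityName": "Sales by Region",
    "Description": "Aggregated sales figures per region and month.",
    "Columns": [  # structure assumed; defined by the search index
        {"Name": "Region", "Type": "VARCHAR"},
        {"Name": "Month", "Type": "DATE"},
        {"Name": "TotalSales", "Type": "DECIMAL"},
    ],
    "SelectFromEntity": "MyDatabase.vSalesByRegion",
}
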
async def query_execution(sql_query: str) -> list[dict]:
    """Run the SQL query against the database.

    Args:
    ----
        sql_query (str): The SQL query to run against the database.

    Returns:
    -------
        list[dict]: The results of the SQL query.
    """
    connection_string = os.environ["Text2Sql__DatabaseConnectionString"]
    async with await aioodbc.connect(dsn=connection_string) as sql_db_client:
        async with sql_db_client.cursor() as cursor:
            await cursor.execute(sql_query)

            columns = [column[0] for column in cursor.description]

            rows = await cursor.fetchall()
            results = [dict(zip(columns, returned_row)) for returned_row in rows]

    logging.debug("Results: %s", results)
    return results

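query_execution passes Text2Sql__DatabaseConnectionString straight to aioodbc as the DSN, so it should hold a full ODBC connection string. A hypothetical example for an Azure SQL database follows; the driver name, server, and credentials are placeholders, not values from this commit.

# Hypothetical ODBC connection string; adjust driver, server, and auth to your setup.
import os

os.environ["Text2Sql__DatabaseConnectionString"] = (
    "Driver={ODBC Driver 18 for SQL Server};"
    "Server=tcp:<your-server>.database.windows.net,1433;"
    "Database=<your-database>;"
    "Uid=<your-user>;Pwd=<your-password>;"
    "Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;"
)
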
async def fetch_queries_from_cache(question: str) -> dict | None:
    """Fetch the queries from the cache based on the question.

    Args:
    ----
        question (str): The question to use to fetch the queries.

    Returns:
    -------
        dict | None: The cached queries (and, optionally, their pre-run results) to inject into the prompt, or None on a cache miss.
    """
    cached_schemas = await run_ai_search_query(
        question,
        ["QuestionEmbedding"],
        ["Question", "SqlQueryDecomposition", "Schemas"],
        os.environ["AIService__AzureSearchOptions__Text2SqlQueryCache__Index"],
        os.environ["AIService__AzureSearchOptions__Text2SqlQueryCache__SemanticConfig"],
        top=1,
        include_scores=True,
        minimum_score=1.5,
    )

    if len(cached_schemas) == 0:
        return None
    else:
        database = os.environ["Text2Sql__DatabaseName"]
        for entry in cached_schemas:
            for schema in entry["Schemas"]:
                entity = schema["Entity"]
                schema["SelectFromEntity"] = f"{database}.{entity}"

    if PRE_RUN_QUERY_CACHE and len(cached_schemas) > 0:
        logging.info("Cached schemas: %s", cached_schemas)

        # Only pre-run the cached queries when the reranker score clears the threshold.
        if cached_schemas[0]["@search.reranker_score"] > 2.75:
            logging.info("Reranker score is greater than 2.75; pre-running cached queries")

            sql_queries = cached_schemas[0]["SqlQueryDecomposition"]
            query_result_store = {}

            query_tasks = []

            for sql_query in sql_queries:
                logging.info("SQL Query: %s", sql_query)

                # Run the SQL query
                query_tasks.append(query_execution(sql_query["SqlQuery"]))

            sql_results = await asyncio.gather(*query_tasks)

            for sql_query, sql_result in zip(sql_queries, sql_results):
                query_result_store[sql_query["SqlQuery"]] = {
                    "result": sql_result,
                    # Index the current decomposition entry, not the whole list.
                    "schemas": sql_query["Schemas"],
                }

            return query_result_store

    return {"cached_questions": cached_schemas}
