Skip to content

Commit 1e43911

Browse files
committed
Update changes
1 parent c55c5a2 commit 1e43911

File tree

4 files changed

+286
-274
lines changed

4 files changed

+286
-274
lines changed

text_2_sql/autogen/agents/custom_agents/sql_query_cache_agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from autogen_agentchat.base import Response
77
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
88
from autogen_core.base import CancellationToken
9-
from utils.sql import fetch_queries_from_cache
9+
from utils.sql import SqlHelper
1010
import json
1111
import logging
1212

@@ -18,6 +18,8 @@ def __init__(self):
1818
"An agent that fetches the queries from the cache based on the user question.",
1919
)
2020

21+
self.sql_helper = SqlHelper()
22+
2123
@property
2224
def produced_message_types(self) -> List[type[ChatMessage]]:
2325
return [TextMessage]
@@ -41,7 +43,7 @@ async def on_messages_stream(
4143
# Fetch the queries from the cache based on the user question.
4244
logging.info("Fetching queries from cache based on the user question...")
4345

44-
cached_queries = await fetch_queries_from_cache(user_question)
46+
cached_queries = await self.sql_helper.fetch_queries_from_cache(user_question)
4547

4648
yield Response(
4749
chat_message=TextMessage(

text_2_sql/autogen/utils/ai_search.py

Lines changed: 109 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -12,121 +12,38 @@
1212
from datetime import datetime, timezone
1313

1414

15-
async def run_ai_search_query(
16-
query,
17-
vector_fields: list[str],
18-
retrieval_fields: list[str],
19-
index_name: str,
20-
semantic_config: str,
21-
top=5,
22-
include_scores=False,
23-
minimum_score: float = None,
24-
):
25-
"""Run the AI search query."""
26-
identity_type = get_identity_type()
27-
28-
async with AsyncAzureOpenAI(
29-
# This is the default and can be omitted
30-
api_key=os.environ["OpenAI__ApiKey"],
31-
azure_endpoint=os.environ["OpenAI__Endpoint"],
32-
api_version=os.environ["OpenAI__ApiVersion"],
33-
) as open_ai_client:
34-
embeddings = await open_ai_client.embeddings.create(
35-
model=os.environ["OpenAI__EmbeddingModel"], input=query
36-
)
37-
38-
# Extract the embedding vector
39-
embedding_vector = embeddings.data[0].embedding
40-
41-
vector_query = VectorizedQuery(
42-
vector=embedding_vector,
43-
k_nearest_neighbors=7,
44-
fields=",".join(vector_fields),
45-
)
46-
47-
if identity_type == IdentityType.SYSTEM_ASSIGNED:
48-
credential = DefaultAzureCredential()
49-
elif identity_type == IdentityType.USER_ASSIGNED:
50-
credential = DefaultAzureCredential(
51-
managed_identity_client_id=os.environ["ClientID"]
52-
)
53-
else:
54-
credential = AzureKeyCredential(
55-
os.environ["AIService__AzureSearchOptions__Key"]
56-
)
57-
async with SearchClient(
58-
endpoint=os.environ["AIService__AzureSearchOptions__Endpoint"],
59-
index_name=index_name,
60-
credential=credential,
61-
) as search_client:
62-
results = await search_client.search(
63-
top=top,
64-
semantic_configuration_name=semantic_config,
65-
search_text=query,
66-
select=",".join(retrieval_fields),
67-
vector_queries=[vector_query],
68-
query_type="semantic",
69-
query_language="en-GB",
70-
)
71-
72-
combined_results = []
73-
74-
async for result in results.by_page():
75-
async for item in result:
76-
if (
77-
minimum_score is not None
78-
and item["@search.reranker_score"] < minimum_score
79-
):
80-
continue
81-
82-
if include_scores is False:
83-
del item["@search.reranker_score"]
84-
del item["@search.score"]
85-
del item["@search.highlights"]
86-
del item["@search.captions"]
87-
88-
logging.info("Item: %s", item)
89-
combined_results.append(item)
15+
class AISearchHelper:
16+
@staticmethod
17+
async def run_ai_search_query(
18+
query,
19+
vector_fields: list[str],
20+
retrieval_fields: list[str],
21+
index_name: str,
22+
semantic_config: str,
23+
top=5,
24+
include_scores=False,
25+
minimum_score: float = None,
26+
):
27+
"""Run the AI search query."""
28+
identity_type = get_identity_type()
9029

91-
logging.info("Results: %s", combined_results)
92-
93-
return combined_results
94-
95-
96-
async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
97-
"""Add an entry to the search index."""
98-
99-
logging.info("Document: %s", document)
100-
logging.info("Vector Fields: %s", vector_fields)
101-
102-
for field in vector_fields.keys():
103-
if field not in document.keys():
104-
logging.error(f"Field {field} is not in the document.")
105-
106-
identity_type = get_identity_type()
107-
108-
fields_to_embed = {field: document[field] for field in vector_fields}
109-
110-
document["DateLastModified"] = datetime.now(timezone.utc)
111-
112-
try:
11330
async with AsyncAzureOpenAI(
11431
# This is the default and can be omitted
11532
api_key=os.environ["OpenAI__ApiKey"],
11633
azure_endpoint=os.environ["OpenAI__Endpoint"],
11734
api_version=os.environ["OpenAI__ApiVersion"],
11835
) as open_ai_client:
11936
embeddings = await open_ai_client.embeddings.create(
120-
model=os.environ["OpenAI__EmbeddingModel"],
121-
input=fields_to_embed.values(),
37+
model=os.environ["OpenAI__EmbeddingModel"], input=query
12238
)
12339

12440
# Extract the embedding vector
125-
for i, field in enumerate(vector_fields.values()):
126-
document[field] = embeddings.data[i].embedding
41+
embedding_vector = embeddings.data[0].embedding
12742

128-
document["Id"] = base64.urlsafe_b64encode(document["Question"].encode()).decode(
129-
"utf-8"
43+
vector_query = VectorizedQuery(
44+
vector=embedding_vector,
45+
k_nearest_neighbors=7,
46+
fields=",".join(vector_fields),
13047
)
13148

13249
if identity_type == IdentityType.SYSTEM_ASSIGNED:
@@ -144,7 +61,92 @@ async def add_entry_to_index(document: dict, vector_fields: dict, index_name: st
14461
index_name=index_name,
14562
credential=credential,
14663
) as search_client:
147-
await search_client.upload_documents(documents=[document])
148-
except Exception as e:
149-
logging.error("Failed to add item to index.")
150-
logging.error("Error: %s", e)
64+
results = await search_client.search(
65+
top=top,
66+
semantic_configuration_name=semantic_config,
67+
search_text=query,
68+
select=",".join(retrieval_fields),
69+
vector_queries=[vector_query],
70+
query_type="semantic",
71+
query_language="en-GB",
72+
)
73+
74+
combined_results = []
75+
76+
async for result in results.by_page():
77+
async for item in result:
78+
if (
79+
minimum_score is not None
80+
and item["@search.reranker_score"] < minimum_score
81+
):
82+
continue
83+
84+
if include_scores is False:
85+
del item["@search.reranker_score"]
86+
del item["@search.score"]
87+
del item["@search.highlights"]
88+
del item["@search.captions"]
89+
90+
logging.info("Item: %s", item)
91+
combined_results.append(item)
92+
93+
logging.info("Results: %s", combined_results)
94+
95+
return combined_results
96+
97+
@staticmethod
98+
async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
99+
"""Add an entry to the search index."""
100+
101+
logging.info("Document: %s", document)
102+
logging.info("Vector Fields: %s", vector_fields)
103+
104+
for field in vector_fields.keys():
105+
if field not in document.keys():
106+
logging.error(f"Field {field} is not in the document.")
107+
108+
identity_type = get_identity_type()
109+
110+
fields_to_embed = {field: document[field] for field in vector_fields}
111+
112+
document["DateLastModified"] = datetime.now(timezone.utc)
113+
114+
try:
115+
async with AsyncAzureOpenAI(
116+
# This is the default and can be omitted
117+
api_key=os.environ["OpenAI__ApiKey"],
118+
azure_endpoint=os.environ["OpenAI__Endpoint"],
119+
api_version=os.environ["OpenAI__ApiVersion"],
120+
) as open_ai_client:
121+
embeddings = await open_ai_client.embeddings.create(
122+
model=os.environ["OpenAI__EmbeddingModel"],
123+
input=fields_to_embed.values(),
124+
)
125+
126+
# Extract the embedding vector
127+
for i, field in enumerate(vector_fields.values()):
128+
document[field] = embeddings.data[i].embedding
129+
130+
document["Id"] = base64.urlsafe_b64encode(
131+
document["Question"].encode()
132+
).decode("utf-8")
133+
134+
if identity_type == IdentityType.SYSTEM_ASSIGNED:
135+
credential = DefaultAzureCredential()
136+
elif identity_type == IdentityType.USER_ASSIGNED:
137+
credential = DefaultAzureCredential(
138+
managed_identity_client_id=os.environ["ClientID"]
139+
)
140+
else:
141+
credential = AzureKeyCredential(
142+
os.environ["AIService__AzureSearchOptions__Key"]
143+
)
144+
async with SearchClient(
145+
endpoint=os.environ["AIService__AzureSearchOptions__Endpoint"],
146+
index_name=index_name,
147+
credential=credential,
148+
) as search_client:
149+
await search_client.upload_documents(documents=[document])
150+
except Exception as e:
151+
logging.error("Failed to add item to index.")
152+
logging.error("Error: %s", e)

text_2_sql/autogen/utils/llm_agent_creator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import yaml
44
from autogen_core.components.tools import FunctionTool
55
from autogen_agentchat.agents import AssistantAgent
6-
from utils.sql import query_execution, get_entity_schemas, query_validation
6+
from utils.sql import SqlHelper
77
from utils.models import MINI_MODEL
88
from jinja2 import Template
99

@@ -24,20 +24,20 @@ def get_model(cls, model_name):
2424
raise ValueError(f"Model {model_name} not found")
2525

2626
@classmethod
27-
def get_tool(cls, tool_name):
27+
def get_tool(cls, sql_helper, tool_name):
2828
if tool_name == "sql_query_execution_tool":
2929
return FunctionTool(
30-
query_execution,
30+
sql_helper.query_execution,
3131
description="Runs an SQL query against the SQL Database to extract information",
3232
)
3333
elif tool_name == "sql_get_entity_schemas_tool":
3434
return FunctionTool(
35-
get_entity_schemas,
35+
sql_helper.get_entity_schemas,
3636
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
3737
)
3838
elif tool_name == "sql_query_validation_tool":
3939
return FunctionTool(
40-
query_validation,
40+
sql_helper.query_validation,
4141
description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.",
4242
)
4343
else:
@@ -55,10 +55,12 @@ def get_property_and_render_parameters(cls, agent_file, property, parameters):
5555
def create(cls, name: str, **kwargs):
5656
agent_file = cls.load_agent_file(name)
5757

58+
sql_helper = SqlHelper()
59+
5860
tools = []
5961
if "tools" in agent_file and len(agent_file["tools"]) > 0:
6062
for tool in agent_file["tools"]:
61-
tools.append(cls.get_tool(tool))
63+
tools.append(cls.get_tool(sql_helper, tool))
6264

6365
agent = AssistantAgent(
6466
name=name,

0 commit comments

Comments
 (0)