Skip to content

Commit 1011c3a

Browse files
committed
Update agent setup
1 parent ddbf1ad commit 1011c3a

11 files changed

+71
-2
lines changed

text_2_sql/autogen/agentic_text_2_sql.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
from utils.llm_agent_creator import LLMAgentCreator
77
import logging
88
from custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
9+
from custom_agents.sql_schema_extraction_agent import SqlSchemaExtractionAgent
910
import json
1011

1112
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
1213
"sql_query_generation_agent",
1314
target_engine="Microsoft SQL Server",
1415
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
1516
)
16-
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create("sql_schema_selection_agent")
17+
# SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create(
18+
# "sql_schema_selection_agent")
19+
SQL_SCHEMA_SELECTION_AGENT = SqlSchemaExtractionAgent()
1720
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
1821
"sql_query_correction_agent",
1922
target_engine="Microsoft SQL Server",
File renamed without changes.

text_2_sql/autogen/custom_agents/sql_query_cache_agent.py renamed to text_2_sql/autogen/agents/custom_agents/sql_query_cache_agent.py

File renamed without changes.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from typing import AsyncGenerator, List, Sequence
4+
5+
from autogen_agentchat.agents import BaseChatAgent
6+
from autogen_agentchat.base import Response
7+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
8+
from autogen_core.base import CancellationToken
9+
from text_2_sql.autogen.utils.sql import get_entity_schemas
10+
from keybert import KeyBERT
11+
import logging
12+
13+
14+
class SqlSchemaExtractionAgent(BaseChatAgent):
15+
def __init__(self):
16+
super().__init__(
17+
"sql_query_cache_agent",
18+
"An agent that fetches the queries from the cache based on the user question.",
19+
)
20+
21+
@property
22+
def produced_message_types(self) -> List[type[ChatMessage]]:
23+
return [TextMessage]
24+
25+
async def on_messages(
26+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
27+
) -> Response:
28+
# Calls the on_messages_stream.
29+
response: Response | None = None
30+
async for message in self.on_messages_stream(messages, cancellation_token):
31+
if isinstance(message, Response):
32+
response = message
33+
assert response is not None
34+
return response
35+
36+
async def on_messages_stream(
37+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
38+
) -> AsyncGenerator[AgentMessage | Response, None]:
39+
user_question = messages[0].content
40+
41+
# Fetch the queries from the cache based on the user question.
42+
logging.info("Fetching queries from cache based on the user question...")
43+
44+
kw_model = KeyBERT()
45+
46+
top_keywords = kw_model.extract_keywords(
47+
user_question, keyphrase_ngram_range=(1, 3), top_n=5
48+
)
49+
50+
# Extract just the key phrases (ignoring the score)
51+
key_phrases = [keyword[0] for keyword in top_keywords]
52+
53+
# Join them into a string list
54+
key_phrases_str = ", ".join(key_phrases)
55+
56+
entity_schemas = await get_entity_schemas(key_phrases_str)
57+
58+
logging.info(entity_schemas)
59+
60+
yield Response(
61+
chat_message=TextMessage(content=entity_schemas, source=self.name)
62+
)
63+
64+
async def on_reset(self, cancellation_token: CancellationToken) -> None:
65+
pass
File renamed without changes.

text_2_sql/autogen/llm_agents/question_decomposition_agent.yaml renamed to text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml

File renamed without changes.

text_2_sql/autogen/llm_agents/sql_query_correction_agent.yaml renamed to text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml

File renamed without changes.

text_2_sql/autogen/llm_agents/sql_query_generation_agent.yaml renamed to text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml

File renamed without changes.

text_2_sql/autogen/llm_agents/sql_schema_selection_agent.yaml renamed to text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml

File renamed without changes.

text_2_sql/autogen/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ azure-identity
88
python-dotenv
99
openai
1010
jinja2
11+
keybert

0 commit comments

Comments
 (0)