Skip to content

Commit da303a2

Browse files
committed
Add parallel agent flow
1 parent fa26325 commit da303a2

File tree

5 files changed

+21
-21
lines changed

5 files changed

+21
-21
lines changed

text_2_sql/autogen/agentic_text_2_sql.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
target_engine="Microsoft SQL Server",
1414
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
1515
)
16-
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create("sql_schema_selection_agent")
16+
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create(
17+
"sql_schema_selection_agent",
18+
use_case="Sales data for a company that specializes in selling products online.",
19+
)
1720
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
1821
"sql_query_correction_agent",
1922
target_engine="Microsoft SQL Server",
@@ -42,7 +45,12 @@ def text_2_sql_generator_selector_func(messages):
4245
decision = "sql_schema_selection_agent"
4346

4447
elif messages[-1].source == "question_decomposition_agent":
45-
decision = "sql_schema_selection_agent"
48+
decomposition_result = json.loads(messages[-1].content)
49+
50+
if len(decomposition_result["entities"]) == 1:
51+
decision = "sql_schema_selection_agent"
52+
else:
53+
decision = "parallel_sql_flow_agent"
4654

4755
elif messages[-1].source == "sql_schema_selection_agent":
4856
decision = "sql_query_generation_agent"

text_2_sql/autogen/agents/custom_agents/sql_schema_extraction_agent.py renamed to text_2_sql/autogen/agents/custom_agents/parallel_sql_flow_agent.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from autogen_agentchat.base import Response
77
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
88
from autogen_core.base import CancellationToken
9-
from text_2_sql.autogen.utils.sql import get_entity_schemas
10-
from keybert import KeyBERT
9+
from utils.sql import fetch_queries_from_cache
10+
import json
1111
import logging
1212

1313

14-
class SqlSchemaExtractionAgent(BaseChatAgent):
14+
class ParallelSqlFlowAgent(BaseChatAgent):
1515
def __init__(self):
1616
super().__init__(
1717
"sql_query_cache_agent",
@@ -41,24 +41,12 @@ async def on_messages_stream(
4141
# Fetch the queries from the cache based on the user question.
4242
logging.info("Fetching queries from cache based on the user question...")
4343

44-
kw_model = KeyBERT()
45-
46-
top_keywords = kw_model.extract_keywords(
47-
user_question, keyphrase_ngram_range=(1, 3), top_n=5
48-
)
49-
50-
# Extract just the key phrases (ignoring the score)
51-
key_phrases = [keyword[0] for keyword in top_keywords]
52-
53-
# Join them into a string list
54-
key_phrases_str = ", ".join(key_phrases)
55-
56-
entity_schemas = await get_entity_schemas(key_phrases_str)
57-
58-
logging.info(entity_schemas)
44+
cached_queries = await fetch_queries_from_cache(user_question)
5945

6046
yield Response(
61-
chat_message=TextMessage(content=entity_schemas, source=self.name)
47+
chat_message=TextMessage(
48+
content=json.dumps(cached_queries), source=self.name
49+
)
6250
)
6351

6452
async def on_reset(self, cancellation_token: CancellationToken) -> None:

text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ description:
77
system_message:
88
"You are a helpful AI Assistant that specialises in selecting relevant SQL schemas to answer a given user's question.
99
10+
The user's question will be related to {{ use_case }}.
11+
1012
Perform the following steps to select the correct schema:
1113
1214
1. Extract key terms and entities from the user's question.

text_2_sql/autogen/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ azure-identity
88
python-dotenv
99
openai
1010
jinja2
11+
pyyaml

text_2_sql/semantic_kernel/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ numpy
99
seaborn
1010
pydantic
1111
openai
12+
pyyaml

0 commit comments

Comments
 (0)