Skip to content

Commit a3bd1d0

Browse files
committed
Update work on autogen
1 parent 681665e commit a3bd1d0

File tree

6 files changed

+170
-132
lines changed

6 files changed

+170
-132
lines changed

text_2_sql/autogen/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
The implementation is written for [AutoGen](https://github.com/microsoft/autogen) in Python, although it can easily be adapted for C#.
44

5+
**Still a work in progress; expect a lot of updates shortly.**
6+
57
**The provided AutoGen code only implements Iteration 5 (Agentic Approach)**
68

79
## Full Logical Flow for Agentic Vector Based Approach

text_2_sql/autogen/agentic_text_2_sql.ipynb

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"Copyright (c) Microsoft Corporation.\n",
8+
"\n",
9+
"Licensed under the MIT License."
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"# Text2SQL with AutoGen & Azure OpenAI\n",
17+
"\n",
18+
"This notebook demonstrates how the AutoGen Agents can be integrated with Azure OpenAI to answer questions from the database based on the schemas provided. \n",
19+
"\n",
20+
"A multi-shot approach is used for SQL generation for more reliable results and reduced token usage. More details can be found in the README.md."
21+
]
22+
},
323
{
424
"cell_type": "code",
525
"execution_count": null,
@@ -9,7 +29,7 @@
929
"import dotenv\n",
1030
"import logging\n",
1131
"from autogen_agentchat.task import Console\n",
12-
"from agentic_text_2_sql import text_2_sql_generator"
32+
"from agentic_text_2_sql import AgenticText2Sql"
1333
]
1434
},
1535
{
@@ -30,13 +50,29 @@
3050
"dotenv.load_dotenv()"
3151
]
3252
},
53+
{
54+
"cell_type": "markdown",
55+
"metadata": {},
56+
"source": [
57+
"## Bot setup"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"metadata": {},
64+
"outputs": [],
65+
"source": [
66+
"agentic_text_2_sql = AgenticText2Sql(target_engine=\"TSQL\", engine_specific_rules=\"Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.\").agentic_flow"
67+
]
68+
},
3369
{
3470
"cell_type": "code",
3571
"execution_count": null,
3672
"metadata": {},
3773
"outputs": [],
3874
"source": [
39-
"result = text_2_sql_generator.run_stream(task=\"What are the total number of sales within 2008?\")"
75+
"result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008?\")"
4076
]
4177
},
4278
{

text_2_sql/autogen/agentic_text_2_sql.py

Lines changed: 126 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -7,81 +7,131 @@
77
import logging
88
from agents.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
99
import json
10+
import os
1011

11-
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
12-
"sql_query_generation_agent",
13-
target_engine="Microsoft SQL Server",
14-
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
15-
)
16-
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create(
17-
"sql_schema_selection_agent",
18-
use_case="Sales data for a company that specializes in selling products online.",
19-
)
20-
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
21-
"sql_query_correction_agent",
22-
target_engine="Microsoft SQL Server",
23-
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
24-
)
25-
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
26-
ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
27-
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create("question_decomposition_agent")
28-
29-
30-
def text_2_sql_generator_selector_func(messages):
31-
logging.info("Messages: %s", messages)
32-
decision = None # Initialize decision variable
33-
34-
if len(messages) == 1:
35-
decision = "sql_query_cache_agent"
36-
37-
elif (
38-
messages[-1].source == "sql_query_cache_agent"
39-
and messages[-1].content is not None
40-
):
41-
cache_result = json.loads(messages[-1].content)
42-
if cache_result.get("cached_questions_and_schemas") is not None:
12+
13+
class AgenticText2Sql:
14+
def __init__(self, target_engine: str, engine_specific_rules: str):
15+
self.use_query_cache = False
16+
self.pre_run_query_cache = False
17+
18+
self.target_engine = target_engine
19+
self.engine_specific_rules = engine_specific_rules
20+
21+
self.set_mode()
22+
23+
def set_mode(self):
24+
"""Set the mode of the plugin based on the environment variables."""
25+
self.use_query_cache = (
26+
os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"
27+
)
28+
29+
self.pre_run_query_cache = (
30+
os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
31+
)
32+
33+
@property
34+
def agents(self):
35+
"""Define the agents for the chat."""
36+
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
37+
"sql_query_generation_agent",
38+
target_engine=self.target_engine,
39+
engine_specific_rules=self.engine_specific_rules,
40+
)
41+
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create(
42+
"sql_schema_selection_agent",
43+
use_case="Sales data for a company that specializes in selling products online.",
44+
)
45+
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
46+
"sql_query_correction_agent",
47+
target_engine=self.target_engine,
48+
engine_specific_rules=self.engine_specific_rules,
49+
)
50+
51+
ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
52+
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
53+
"question_decomposition_agent"
54+
)
55+
56+
agents = [
57+
SQL_QUERY_GENERATION_AGENT,
58+
SQL_SCHEMA_SELECTION_AGENT,
59+
SQL_QUERY_CORRECTION_AGENT,
60+
ANSWER_AGENT,
61+
QUESTION_DECOMPOSITION_AGENT,
62+
]
63+
64+
if self.use_query_cache:
65+
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
66+
agents.append(SQL_QUERY_CACHE_AGENT)
67+
68+
return agents
69+
70+
@property
71+
def termination_condition(self):
72+
"""Define the termination condition for the chat."""
73+
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10)
74+
return termination
75+
76+
@staticmethod
77+
def selector(messages):
78+
logging.info("Messages: %s", messages)
79+
decision = None # Initialize decision variable
80+
81+
if len(messages) == 1:
82+
decision = "sql_query_cache_agent"
83+
84+
elif (
85+
messages[-1].source == "sql_query_cache_agent"
86+
and messages[-1].content is not None
87+
):
88+
cache_result = json.loads(messages[-1].content)
89+
if cache_result.get("cached_questions_and_schemas") is not None:
90+
decision = "sql_query_correction_agent"
91+
else:
92+
decision = "sql_schema_selection_agent"
93+
94+
elif messages[-1].source == "sql_query_cache_agent":
95+
decision = "question_decomposition_agent"
96+
97+
elif messages[-1].source == "question_decomposition_agent":
98+
decomposition_result = json.loads(messages[-1].content)
99+
100+
if len(decomposition_result["entities"]) == 1:
101+
decision = "sql_schema_selection_agent"
102+
else:
103+
decision = "parallel_sql_flow_agent"
104+
105+
elif messages[-1].source == "sql_schema_selection_agent":
106+
decision = "sql_query_generation_agent"
107+
108+
elif (
109+
messages[-1].source == "sql_query_correction_agent"
110+
and messages[-1].content == "VALIDATED"
111+
):
112+
decision = "answer_agent"
113+
114+
elif messages[-1].source == "sql_query_correction_agent":
43115
decision = "sql_query_correction_agent"
44-
else:
45-
decision = "sql_schema_selection_agent"
46-
47-
elif messages[-1].source == "question_decomposition_agent":
48-
decomposition_result = json.loads(messages[-1].content)
49-
50-
if len(decomposition_result["entities"]) == 1:
51-
decision = "sql_schema_selection_agent"
52-
else:
53-
decision = "parallel_sql_flow_agent"
54-
55-
elif messages[-1].source == "sql_schema_selection_agent":
56-
decision = "sql_query_generation_agent"
57-
58-
elif (
59-
messages[-1].source == "sql_query_correction_agent"
60-
and messages[-1].content == "VALIDATED"
61-
):
62-
decision = "answer_agent"
63-
64-
elif messages[-1].source == "sql_query_correction_agent":
65-
decision = "sql_query_correction_agent"
66-
67-
# Log the decision
68-
logging.info("Decision: %s", decision)
69-
70-
return decision
71-
72-
73-
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10)
74-
text_2_sql_generator = SelectorGroupChat(
75-
[
76-
SQL_QUERY_GENERATION_AGENT,
77-
SQL_SCHEMA_SELECTION_AGENT,
78-
SQL_QUERY_CORRECTION_AGENT,
79-
SQL_QUERY_CACHE_AGENT,
80-
ANSWER_AGENT,
81-
QUESTION_DECOMPOSITION_AGENT,
82-
],
83-
allow_repeated_speaker=False,
84-
model_client=MINI_MODEL,
85-
termination_condition=termination,
86-
selector_func=text_2_sql_generator_selector_func,
87-
)
116+
117+
# Log the decision
118+
logging.info("Decision: %s", decision)
119+
120+
return decision
121+
122+
@property
123+
def agentic_flow(self):
124+
"""Build the SelectorGroupChat that drives the agentic flow.
125+
126+
Returns:
127+
-------
128+
SelectorGroupChat: The group chat configured with this instance's agents, termination condition and selector function."""
129+
agentic_flow = SelectorGroupChat(
130+
self.agents,
131+
allow_repeated_speaker=False,
132+
model_client=MINI_MODEL,
133+
termination_condition=self.termination_condition,
134+
selector_func=AgenticText2Sql.selector,
135+
)
136+
137+
return agentic_flow

text_2_sql/autogen/agents/custom_agents/parallel_sql_flow_agent.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

text_2_sql/autogen/utils/sql.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,13 @@ async def query_validation(
105105
) -> Union[bool | list[dict]]:
106106
"""Validate the SQL query."""
107107
try:
108+
logging.info("Validating SQL Query: %s", sql_query)
108109
sqlglot.transpile(sql_query)
109110
except sqlglot.errors.ParseError as e:
111+
logging.error("SQL Query is invalid: %s", e.errors)
110112
return e.errors
111113
else:
114+
logging.info("SQL Query is valid.")
112115
return True
113116

114117

text_2_sql/data_dictionary/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,6 @@ The following Databases have pre-built scripts for them:
101101

102102
- **Databricks:** `databricks_data_dictionary_creator.py`
103103
- **Snowflake:** `snowflake_data_dictionary_creator.py`
104-
- **SQL Server:** `tsql_data_dictionary_creator.py`
104+
- **TSQL:** `tsql_data_dictionary_creator.py`
105105

106106
If there is no pre-built script for your database engine, take one of the above as a starting point and adjust it.

0 commit comments

Comments
 (0)