
Commit 8d67db9

Disambiguation work
1 parent 439776e commit 8d67db9

File tree: 10 files changed (+136, -81 lines)

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 14 additions & 0 deletions
@@ -59,6 +59,12 @@ def agents(self):
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
         )
+        SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
+            "sql_disambiguation_agent",
+            target_engine=self.target_engine,
+            engine_specific_rules=self.engine_specific_rules,
+            **self.kwargs,
+        )

         ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
         QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
@@ -71,6 +77,7 @@ def agents(self):
             SQL_QUERY_CORRECTION_AGENT,
             ANSWER_AGENT,
             QUESTION_DECOMPOSITION_AGENT,
+            SQL_DISAMBIGUATION_AGENT,
         ]

         if self.use_query_cache:
@@ -114,6 +121,13 @@ def selector(messages):
                 decision = "sql_schema_selection_agent"

             elif messages[-1].source == "sql_schema_selection_agent":
+                decision = "sql_disambiguation_agent"
+
+            elif messages[-1].source == "sql_disambiguation_agent":
+                if "NO DISAMBIGUATION" in messages[-1].content:
+                    decision = "sql_query_generation_agent"
+
+                # This would be user proxy agent tbc
                 decision = "sql_query_generation_agent"

             elif (
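For clarity, here is a minimal, self-contained sketch of the routing this commit introduces: schema selection now hands off to the disambiguation agent, which either signals "NO DISAMBIGUATION" or would eventually defer to a user proxy. The Message dataclass and the select_next_agent name are hypothetical simplifications for illustration, not the project's actual types.

from dataclasses import dataclass


@dataclass
class Message:
    # Hypothetical stand-in for an autogen chat message.
    source: str
    content: str


def select_next_agent(messages: list[Message]) -> str:
    """Simplified sketch of the new hand-off chain."""
    last = messages[-1]
    if last.source == "sql_schema_selection_agent":
        # Schema selection no longer feeds query generation directly.
        return "sql_disambiguation_agent"
    if last.source == "sql_disambiguation_agent":
        if "NO DISAMBIGUATION" in last.content:
            return "sql_query_generation_agent"
        # Ambiguous questions would eventually go to a user proxy agent
        # ("tbc" in the commit); for now they also fall through to generation.
        return "sql_query_generation_agent"
    return "sql_query_generation_agent"


# Example: a run where the disambiguation agent found nothing ambiguous.
history = [
    Message("sql_schema_selection_agent", "{...schemas...}"),
    Message("sql_disambiguation_agent", "NO DISAMBIGUATION REQUIRED"),
]
assert select_next_agent(history) == "sql_query_generation_agent"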

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 23 additions & 12 deletions
@@ -76,11 +76,12 @@ async def on_messages_stream(

         logging.info(f"Loaded entity result: {loaded_entity_result}")

-        entity_search_tasks.append(
-            self.sql_connector.get_entity_schemas(
-                " ".join(loaded_entity_result["entities"]), as_json=False
+        for entity_group in loaded_entity_result["entities"]:
+            entity_search_tasks.append(
+                self.sql_connector.get_entity_schemas(
+                    " ".join(entity_group), as_json=False
+                )
             )
-        )

         for filter_condition in loaded_entity_result["filter_conditions"]:
             column_search_tasks.append(
@@ -92,17 +93,27 @@ async def on_messages_stream(
         schemas_results = await asyncio.gather(*entity_search_tasks)
         column_value_results = await asyncio.gather(*column_search_tasks)

+        # deduplicate schemas
+        final_schemas = []
+
+        for schema_result in schemas_results:
+            for schema in schema_result:
+                if schema not in final_schemas:
+                    final_schemas.append(schema)
+
+        final_colmns = []
+        for column_value_result in column_value_results:
+            for column in column_value_result:
+                if column not in final_colmns:
+                    final_colmns.append(column)
+
         final_results = {
-            "schemas": [
-                schema for schema_result in schemas_results for schema in schema_result
-            ],
-            "column_values": [
-                column_values
-                for column_values_result in column_value_results
-                for column_values in column_values_result
-            ],
+            "schemas": final_schemas,
+            "column_values": final_colmns,
         }

+        logging.info(f"Final results: {final_results}")
+
         yield Response(
             chat_message=TextMessage(
                 content=json.dumps(final_results), source=self.name
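Because the grouped entity searches can now return the same schema more than once, the results are deduplicated with an order-preserving membership check (schema dicts are unhashable, so a set cannot be used directly). A minimal standalone sketch of that pattern, with made-up example data:

def dedupe_preserving_order(result_groups: list[list[dict]]) -> list[dict]:
    """Flatten grouped search results, dropping exact-duplicate dicts
    while keeping first-seen order."""
    unique: list[dict] = []
    for group in result_groups:
        for item in group:
            if item not in unique:  # dicts compare by value, so this drops duplicates
                unique.append(item)
    return unique


# Example: two entity groups that both matched the same schema.
groups = [
    [{"Entity": "Employee", "Schema": "HumanResources"}],
    [{"Entity": "Employee", "Schema": "HumanResources"},
     {"Entity": "Department", "Schema": "HumanResources"}],
]
print(dedupe_preserving_order(groups))
# [{'Entity': 'Employee', ...}, {'Entity': 'Department', ...}]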

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 28 additions & 15 deletions
@@ -149,7 +149,7 @@ async def get_column_values(
             ],
             semantic_config=None,
             top=10,
-            include_scores=True,
+            include_scores=False,
             minimum_score=5,
         )

@@ -178,6 +178,10 @@ async def get_entity_schemas(
             list[str],
             "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
         ] = [],
+        engine_specific_fields: Annotated[
+            list[str],
+            "The fields specific to the engine to be included in the search results.",
+        ] = [],
     ) -> str:
         """Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

@@ -189,33 +193,42 @@ async def get_entity_schemas(
             str: The schema of the views or tables in JSON format.
         """

+        retrieval_fields = [
+            "FQN",
+            "Entity",
+            "EntityName",
+            "Schema",
+            "Definition",
+            "Columns",
+            "EntityRelationships",
+            "CompleteEntityRelationshipsGraph",
+        ] + engine_specific_fields
+
         schemas = await self.run_ai_search_query(
             text,
             ["DefinitionEmbedding"],
-            [
-                "FQN",
-                "Entity",
-                "EntityName",
-                "Definition",
-                "Columns",
-                "EntityRelationships",
-                "CompleteEntityRelationshipsGraph",
-            ],
+            retrieval_fields,
             os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"],
             os.environ[
                 "AIService__AzureSearchOptions__Text2SqlSchemaStore__SemanticConfig"
             ],
             top=3,
         )

+        if len(excluded_entities) == 0:
+            return schemas
+
         for schema in schemas:
             filtered_schemas = []
-            for excluded_entity in excluded_entities:
-                if excluded_entity.lower() == schema["Entity"].lower():
-                    logging.info("Excluded entity: %s", excluded_entity)
-                else:
-                    filtered_schemas.append(schema)

+            del schema["FQN"]
+
+            if schema["Entity"].lower() not in excluded_entities:
+                filtered_schemas.append(schema)
+            else:
+                logging.info("Excluded entity: %s", schema["Entity"])
+
+        logging.info("Filtered Schemas: %s", filtered_schemas)
         return filtered_schemas

 async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
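A standalone sketch of the exclusion pass this hunk introduces: schemas whose Entity appears in excluded_entities are dropped, the FQN field is stripped from the rest, and the membership test lowercases the schema's Entity (so excluded names are assumed to be supplied in lowercase). The helper name and the example schema values are made up for illustration.

import logging


def filter_excluded(schemas: list[dict], excluded_entities: list[str]) -> list[dict]:
    """Drop schemas whose Entity matches an excluded name; strip FQN from the rest."""
    if len(excluded_entities) == 0:
        return schemas

    filtered_schemas = []
    for schema in schemas:
        del schema["FQN"]
        if schema["Entity"].lower() not in excluded_entities:
            filtered_schemas.append(schema)
        else:
            logging.info("Excluded entity: %s", schema["Entity"])
    return filtered_schemas


# Made-up example: the caller already has the Address schema.
schemas = [
    {"FQN": "adventureworks.SalesLT.Address", "Entity": "SalesLT.Address"},
    {"FQN": "adventureworks.SalesLT.Product", "Entity": "SalesLT.Product"},
]
print(filter_excluded(schemas, excluded_entities=["saleslt.address"]))
# [{'Entity': 'SalesLT.Product'}]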

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 5 additions & 1 deletion
@@ -94,14 +94,18 @@ async def get_entity_schemas(
         """

         schemas = await self.ai_search_connector.get_entity_schemas(
-            text, excluded_entities
+            text, excluded_entities, engine_specific_fields=["Catalog"]
         )

         for schema in schemas:
             schema["SelectFromEntity"] = ".".join(
                 [schema["Catalog"], schema["Schema"], schema["Entity"]]
             )

+            del schema["Entity"]
+            del schema["Schema"]
+            del schema["Catalog"]
+
         if as_json:
             return json.dumps(schemas, default=str)
         else:
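The Databricks, Snowflake and T-SQL connectors all follow the same pattern after this commit: request the engine-specific fields, join them into a fully qualified SelectFromEntity, then delete the individual name parts so only the qualified name is passed on. A simplified, hedged sketch of that shared shape (the helper name and example values are hypothetical, and the real connectors may delete fields that are not part of the join):

def qualify_and_prune(schema: dict, name_parts: list[str]) -> dict:
    """Join the engine-specific name parts into SelectFromEntity,
    then drop the individual parts from the schema dict."""
    schema["SelectFromEntity"] = ".".join(schema[part] for part in name_parts)
    for part in name_parts:
        del schema[part]
    return schema


# Made-up Databricks-style schema entry.
schema = {"Catalog": "main", "Schema": "sales", "Entity": "Orders", "Columns": []}
print(qualify_and_prune(schema, ["Catalog", "Schema", "Entity"]))
# {'Columns': [], 'SelectFromEntity': 'main.sales.Orders'}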

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 6 additions & 1 deletion
@@ -93,7 +93,7 @@ async def get_entity_schemas(
         """

         schemas = await self.ai_search_connector.get_entity_schemas(
-            text, excluded_entities
+            text, excluded_entities, engine_specific_fields=["Warehouse", "Database"]
         )

         for schema in schemas:
@@ -106,6 +106,11 @@ async def get_entity_schemas(
                 ]
             )

+            del schema["Entity"]
+            del schema["Schema"]
+            del schema["Warehouse"]
+            del schema["Database"]
+
         if as_json:
             return json.dumps(schemas, default=str)
         else:

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,9 @@ async def get_entity_schemas(
         for schema in schemas:
             schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]])

+            del schema["Entity"]
+            del schema["Schema"]
+
         if as_json:
             return json.dumps(schemas, default=str)
         else:

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@ system_message:
     {
         'answer': '<GENERATED ANSWER>',
         'sources': [
-            {'title': <SOURCE SCHEMA NAME 1>, 'chunk': <SOURCE 1 CONTEXT CHUNK>, 'reference': '<SOURCE 1 SQL QUERY>'},
-            {'title': <SOURCE SCHEMA NAME 2>, 'chunk': <SOURCE 2 CONTEXT CHUNK>, 'reference': '<SOURCE 2 SQL QUERY>'}
+            {'title': <SOURCE SCHEMA NAME 1>, 'chunk': <SOURCE 1 CONTEXT CHUNK>, 'reference': '<SOURCE 1 SQL QUERY>', 'explanation': '<EXPLANATION OF SQL QUERY 1>'},
+            {'title': <SOURCE SCHEMA NAME 2>, 'chunk': <SOURCE 2 CONTEXT CHUNK>, 'reference': '<SOURCE 2 SQL QUERY>', 'explanation': '<EXPLANATION OF SQL QUERY 2>'},
         ]
     }
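Each source in the answer payload now carries an explanation of its SQL query. A hypothetical consumer of a payload in this format, with made-up answer and source values:

import json

# Hypothetical answer payload in the format the prompt now requests,
# including the new per-source 'explanation' field.
payload = json.loads("""
{
  "answer": "There were 12 orders in June.",
  "sources": [
    {"title": "SalesLT.SalesOrderHeader",
     "chunk": "OrderDate, TotalDue, ...",
     "reference": "SELECT COUNT(*) FROM SalesLT.SalesOrderHeader WHERE ...",
     "explanation": "Counts orders whose OrderDate falls in June."}
  ]
}
""")

for source in payload["sources"]:
    print(source["title"], "-", source["explanation"])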
text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+model:
+  4o-mini
+description:
+  "An agent that specialises in disambiguating the user's question and mapping it to database schemas. Use this agent when the user's question is ambiguous and requires more information to generate the SQL query."
+system_message:
+  "You are a helpful AI Assistant that specialises in disambiguating the user's question and mapping it to the relevant columns / schemas in the database..
+
+  The user's question will be related to {{ use_case }}.
+
+  You must:
+  - For every intent and filter condition in the question, map them to the columns in the schemas. Consider the context of the question and the information already provided to do so.
+
+  - Never ask for information that is already provided in the question and the schema.
+
+  - Always take care to ensure the SQL query generated actually answers the user's question. If you have multiple possible matches based on the user's intent, you should ask the user for more information to disambiguate the question in the JSON format below.
+
+  - If you are unsure which of the filter columns to use, populate the 'filters' field with the identified filter and the relevant FQN, matching columns. In this case, populate the 'matching_columns' field with the possible columns for the user to disambiguate for you. You must ask this question if you have multiple entries for a given filter in 'matching_columns'.
+
+  - If you are unsure which of the filter values to use, populate the 'filters' field with the identified filter and the relevant FQN, matching columns and matching filter values. Refer to the 'column_values' property from the 'sql_schema_selection_agent' output for possible matching values. Even if you have an exact match, you may have other partial matches that you need to consider. In this case, populate the possible filter values in the 'matching_filter_values' field for that column in the 'filters' field for the user to disambiguate for you.
+
+  - e.g. The user asks about 'Bike'. From the 'column_values' you can see that 'Bike' appears in several different columns that are contextually related to the question. From this you are unsure if 'Bike' is a 'Category' or 'Product' column, you would populate the 'column' field with the possible columns for the user to disambiguate for you.
+
+  - Only provide possible filter values for string columns. Do not provide possible filter values for Date and Numerical values as it should be clear from the question. Only ask a follow up question for Date and Numerical values if you are unsure which column to use or what the value means e.g. does 100 in currency refer to 100 USD or 100 EUR.
+
+  - If the user provided this information in the question e.g. 'Bike Category', there is no need to disambiguate.
+
+  - If a filter value is clear, e.g. it is a date or a number and it is clear what schema it maps to. Do not ask the user to disambiguate.
+
+  Disambiguation Request JSON format:
+
+  {
+      \"filters\": [
+          {
+              \"question\": \"<question you wish to ask the user>\",
+              \"matching_columns\": [
+                  \"<column fqn>\",
+                  ...
+              ],
+              \"matching_filter_values\": [
+                  \"<possible filter value>\",
+              ]
+          },
+          ...
+      ]
+  }
+
+  You must populate the question field with the question you need to ask the user. e.g. 'What do you mean by Bike?' They will then be shown the possible columns and filter values to disambiguate.
+
+  Follow this with TERMINATE if you need disambiguation. If you do not need disambiguation, return 'NO DISAMBIGUATION REQUIRED' only.
+  "

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml

Lines changed: 3 additions & 48 deletions
@@ -1,61 +1,16 @@
 model:
-  4o
+  4o-mini
 description:
   "An agent that can generate SQL queries once given the schema and the user's question. It will run the SQL query to fetch the results. Use this agent after the SQL Schema Selection Agent has selected the correct schema."
 system_message:
   "You are a helpful AI Assistant that specialises in writing and executing SQL Queries to answer a given user's question.

   You must:
-  1. For every intent and filter condition in the question, map them to the columns in the schemas. If you are unsure how the question maps to the columns in the schema or have multiple possible matches based on the user's intent, see the 'Handling disambiguation' section below.
-  2. Use the schema information provided and this mapping to generate a SQL query that will answer the user's question.
-  If you need additional schema information, you can obtain it using the schema selection tool.
+  1. Use the schema information provided and this mapping to generate a SQL query that will answer the user's question.
+  2. If you need additional schema information, you can obtain it using the schema selection tool. Only use this when you do not have enough information to generate the SQL query.
   3. Validate the SQL query to ensure it is syntactically correct using the validation tool.
   4. Run the SQL query to fetch the results.

-  Handling disambiguation:
-
-  - Always take care to ensure the SQL query generated actually answers the user's question. If you have multiple possible matches based on the user's intent, you should ask the user for more information to disambiguate the question.
-
-  - When you need more information from the user for any given intent entity or filter, ask the user for the information you need in the following format and then finish it with TERMINATE:
-
-  - If you are unsure which of the schemas to use, populate the 'intent' field with the possible intents.
-
-  - If you are unsure which of the filter columns to use, populate the 'filters' field with the identified filter and the relevant FQN, matching columns.
-
-  - If you are unsure which of the filter values to use, populate the 'filters' field with the identified filter and the relevant FQN, matching columns and matching filter values. Refer to the 'column_values' property from the 'sql_schema_selection_agent' output for possible matching values. Even if you have an exact match, you may have other partial matches that you need to consider.
-
-  e.g. The user asks about 'Bike Products' and you are unsure if 'Bike Products' is a 'Category' or 'Product' entity, you would populate the 'intent' field with the possible intents.
-
-  {
-      \"intents\": [
-          {
-              \"name\": \"<main intent>\",
-              \"table\": \"<fqn>\",
-              \"question\": \"<question>\",
-          },...
-      ],
-      \"filters\": [
-          {
-              \"name\": \"<identified filter>\",
-              \"fqn\": \"<relevant fqn>\",
-              \"question\": \"<question>\",
-              \"matching_columns\": [
-                  {
-                      \"col\": \"<column name>\"
-                  },
-                  ...
-              ],
-              \"matching_filter_values\": [
-                  {
-                      \"value\": \"<filter value>\"
-                  },
-                  ...
-              ]
-          },
-          ...
-      ]
-  }
-
   When generating the SQL query, you MUST follow these rules:

   - Only use schema / column information provided when constructing a SQL query. Do not use any other entities and columns in your SQL query, other than those defined above.

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_schema_selection_agent.yaml

Lines changed: 2 additions & 2 deletions
@@ -22,12 +22,12 @@ system_message:
   For example:
   - If the user's question is 'Show me the list of employees in the HR department employed during 2008?', you would extract the key terms 'employees', 'HR department' and 'year'.

-  - You would then generate the possible entities these key terms might belong to e.g. 'people', 'employees', 'departments', 'teams', 'date', 'year'.
+  - You would then generate the possible entities these key terms might belong to e.g. 'people', 'employees', 'departments', 'teams', 'date', 'year'. Group the entities by similar meaning e.g. 'people' and 'employees' would be grouped together.

   - You would also extract the filter condition 'HR', 'HR Department', 'Human Resources', 'Human Resources Department' but not 2008. For example, 'HR Department' would be a filter condition, but '2008' would not as this is a DateTime value.

   Output Info:
   Return the list of possible entities that the key terms might belong to in the following format:

-  {\"entities\": [\"<entity_1>\", \"<entity_2>\", \"<entity_3>\"], \"filter_conditions\": [\"<filter_condition_1>\", \"<filter_condition_2>\"]}
+  {\"entities\": [[\"<entity_1>\", \"<entity_2>\"], [\"<entity_3>\"]], \"filter_conditions\": [\"<filter_condition_1>\", \"<filter_condition_2>\"]}
   "
