Skip to content

Commit 82ede21

Browse files
Fix import error for text2sql and improve naming (#139)
* Update naming * Fix poor naming in connectors * Update cli
1 parent 2a54f4a commit 82ede21

File tree

10 files changed

+83
-54
lines changed

10 files changed

+83
-54
lines changed

text_2_sql/autogen/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ Each agent can be configured with specific parameters and prompts to optimize it
134134

135135
## Query Cache Implementation Details
136136

137-
The vector based with query cache uses the `fetch_queries_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.
137+
The vector-based approach with query cache uses the `fetch_sql_queries_with_schemas_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.
138138

139139
If the score of the top result is higher than the defined threshold, the query will be executed against the target data source and the results included in the prompt. This allows us to prompt the LLM to evaluate whether it can use these results to answer the question, **without further SQL Query generation** to speed up the process.
140140

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,14 @@ async def consume_inner_messages_from_agentic_flow(
152152
# Search for specific message types and add them to the final output object
153153
if isinstance(parsed_message, dict):
154154
# Check if the message contains pre-run results
155-
if ("contains_pre_run_results" in parsed_message) and (
156-
parsed_message["contains_pre_run_results"] is True
155+
if (
156+
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
157+
in parsed_message
158+
) and (
159+
parsed_message[
160+
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
161+
]
162+
is True
157163
):
158164
logging.info("Contains pre-run results")
159165
for pre_run_sql_query, pre_run_result in parsed_message[

text_2_sql/previous_iterations/semantic_kernel/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ This method is called by the Semantic Kernel framework automatically, when instr
134134

135135
The search text passed is vectorised against the entity level **Description** columns. A hybrid Semantic Reranking search is applied against the **EntityName**, **Entity**, **Columns/Name** fields.
136136

137-
#### fetch_queries_from_cache()
137+
#### fetch_sql_queries_with_schemas_from_cache()
138138

139-
The vector based with query cache uses the `fetch_queries_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.
139+
The vector-based approach with query cache uses the `fetch_sql_queries_with_schemas_from_cache()` method to fetch the most relevant previous query and injects it into the prompt before the initial LLM call. The use of Auto-Function Calling here is avoided to reduce the response time as the cache index will always be used first.
140140

141141
If the score of the top result is higher than the defined threshold, the query will be executed against the target data source and the results included in the prompt. This allows us to prompt the LLM to evaluate whether it can use these results to answer the question, **without further SQL Query generation** to speed up the process.
142142

text_2_sql/previous_iterations/semantic_kernel/plugins/vector_based_sql_plugin/vector_based_sql_plugin.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def fetch_schemas_from_store(self, search: str) -> list[dict]:
137137

138138
return schemas
139139

140-
async def fetch_queries_from_cache(self, question: str) -> str:
140+
async def fetch_sql_queries_with_schemas_from_cache(self, question: str) -> str:
141141
"""Fetch the queries from the cache based on the question.
142142
143143
Args:
@@ -151,7 +151,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
151151
if not self.use_query_cache:
152152
return None
153153

154-
cached_schemas = await self.ai_search.run_ai_search_query(
154+
sql_queries_with_schemas = await self.ai_search.run_ai_search_query(
155155
question,
156156
["QuestionEmbedding"],
157157
["Question", "SqlQueryDecomposition"],
@@ -164,26 +164,28 @@ async def fetch_queries_from_cache(self, question: str) -> str:
164164
minimum_score=1.5,
165165
)
166166

167-
if len(cached_schemas) == 0:
167+
if len(sql_queries_with_schemas) == 0:
168168
return None
169169
else:
170170
database = os.environ["Text2Sql__DatabaseName"]
171-
for entry in cached_schemas["SqlQueryDecomposition"]:
171+
for entry in sql_queries_with_schemas["SqlQueryDecomposition"]:
172172
for schema in entry["Schemas"]:
173173
entity = schema["Entity"]
174174
schema["SelectFromEntity"] = f"{database}.{entity}"
175175

176176
self.schemas[entity] = schema
177177

178178
pre_fetched_results_string = ""
179-
if self.pre_run_query_cache and len(cached_schemas) > 0:
180-
logging.info("Cached schemas: %s", cached_schemas)
179+
if self.pre_run_query_cache and len(sql_queries_with_schemas) > 0:
180+
logging.info(
181+
"Cached SQL Queries with Schemas: %s", sql_queries_with_schemas
182+
)
181183

182184
# check the score
183-
if cached_schemas[0]["@search.reranker_score"] > 2.75:
185+
if sql_queries_with_schemas[0]["@search.reranker_score"] > 2.75:
184186
logging.info("Score is greater than 3")
185187

186-
sql_queries = cached_schemas[0]["SqlQueryDecomposition"]
188+
sql_queries = sql_queries_with_schemas[0]["SqlQueryDecomposition"]
187189
query_result_store = {}
188190

189191
query_tasks = []
@@ -208,7 +210,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
208210
return pre_fetched_results_string
209211

210212
formatted_sql_cache_string = f"""[BEGIN CACHED QUERIES AND SCHEMAS]:\n{
211-
json.dumps(cached_schemas, default=str)}[END CACHED QUERIES AND SCHEMAS]"""
213+
json.dumps(sql_queries_with_schemas, default=str)}[END CACHED QUERIES AND SCHEMAS]"""
212214

213215
return formatted_sql_cache_string
214216

@@ -230,7 +232,9 @@ async def sql_prompt_injection(
230232
self.set_mode()
231233

232234
if self.use_query_cache:
233-
query_cache_string = await self.fetch_queries_from_cache(question)
235+
query_cache_string = await self.fetch_sql_queries_with_schemas_from_cache(
236+
question
237+
)
234238
else:
235239
query_cache_string = None
236240

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ def get_database_connector():
2222

2323
return SnowflakeSqlConnector()
2424
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "TSQL":
25-
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector
25+
from text_2_sql_core.connectors.tsql_sql import TsqlSqlConnector
2626

27-
return TSQLSqlConnector()
27+
return TsqlSqlConnector()
2828
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL":
2929
from text_2_sql_core.connectors.postgresql_sql import (
3030
PostgresqlSqlConnector,

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def handle_node(node):
260260
logging.info("SQL Query is valid.")
261261
return True
262262

263-
async def fetch_queries_from_cache(
263+
async def fetch_sql_queries_with_schemas_from_cache(
264264
self, question: str, injected_parameters: dict = None
265265
) -> str:
266266
"""Fetch the queries from the cache based on the question.
@@ -276,14 +276,14 @@ async def fetch_queries_from_cache(
276276
# Return empty results if AI Search is disabled
277277
if not self.use_ai_search:
278278
return {
279-
"contains_pre_run_results": False,
280-
"cached_questions_and_schemas": None,
279+
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
280+
"cached_sql_queries_with_schemas_from_cache": None,
281281
}
282282

283283
if injected_parameters is None:
284284
injected_parameters = {}
285285

286-
cached_schemas = await self.ai_search_connector.run_ai_search_query(
286+
sql_queries_with_schemas = await self.ai_search_connector.run_ai_search_query(
287287
question,
288288
["QuestionEmbedding"],
289289
["Question", "SqlQueryDecomposition"],
@@ -294,51 +294,51 @@ async def fetch_queries_from_cache(
294294
minimum_score=1.5,
295295
)
296296

297-
if len(cached_schemas) == 0:
297+
if len(sql_queries_with_schemas) == 0:
298298
return {
299-
"contains_pre_run_results": False,
300-
"cached_questions_and_schemas": None,
299+
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
300+
"cached_sql_queries_with_schemas_from_cache": None,
301301
}
302302

303303
# loop through all sql queries and populate the template in place
304-
for schema in cached_schemas:
305-
sql_queries = schema["SqlQueryDecomposition"]
306-
for sql_query in sql_queries:
304+
for queries_with_schemas in sql_queries_with_schemas:
305+
for sql_query in queries_with_schemas["SqlQueryDecomposition"]:
307306
sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render(
308307
**injected_parameters
309308
)
310309

311-
logging.info("Cached schemas: %s", cached_schemas)
312-
if self.pre_run_query_cache and len(cached_schemas) > 0:
310+
logging.info("Cached SQL Queries with Schemas: %s", sql_queries_with_schemas)
311+
if self.pre_run_query_cache and len(sql_queries_with_schemas) > 0:
313312
# check the score
314-
if cached_schemas[0]["@search.reranker_score"] > 2.75:
313+
if sql_queries_with_schemas[0]["@search.reranker_score"] > 2.75:
315314
logging.info("Score is greater than 3")
316315

317-
sql_queries = cached_schemas[0]["SqlQueryDecomposition"]
318316
query_result_store = {}
319317

320318
query_tasks = []
321319

322-
for sql_query in sql_queries:
320+
for sql_query in sql_queries_with_schemas[0]["SqlQueryDecomposition"]:
323321
logging.info("SQL Query: %s", sql_query)
324322

325323
# Run the SQL query
326324
query_tasks.append(self.query_execution(sql_query["SqlQuery"]))
327325

328326
sql_results = await asyncio.gather(*query_tasks)
329327

330-
for sql_query, sql_result in zip(sql_queries, sql_results):
328+
for sql_query, sql_result in zip(
329+
sql_queries_with_schemas[0]["SqlQueryDecomposition"], sql_results
330+
):
331331
query_result_store[sql_query["SqlQuery"]] = {
332332
"sql_rows": sql_result,
333333
"schemas": sql_query["Schemas"],
334334
}
335335

336336
return {
337-
"contains_pre_run_results": True,
338-
"cached_questions_and_schemas": query_result_store,
337+
"contains_cached_sql_queries_with_schemas_from_cache_database_results": True,
338+
"cached_sql_queries_with_schemas_from_cache": query_result_store,
339339
}
340340

341341
return {
342-
"contains_pre_run_results": False,
343-
"cached_questions_and_schemas": cached_schemas,
342+
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
343+
"cached_sql_queries_with_schemas_from_cache": sql_queries_with_schemas,
344344
}

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1111

1212

13-
class TSQLSqlConnector(SqlConnector):
13+
class TsqlSqlConnector(SqlConnector):
1414
def __init__(self):
1515
super().__init__()
1616

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,33 @@ async def process_message(
1313
) -> dict:
1414
# Initialize results dictionary
1515
cached_results = {
16-
"cached_questions_and_schemas": [],
17-
"contains_pre_run_results": False,
16+
"cached_sql_queries_with_schemas_from_cache": [],
17+
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
1818
}
1919

2020
# Process each question sequentially
2121
for message in messages:
2222
# Fetch the queries from the cache based on the question
2323
logging.info(f"Fetching queries from cache for question: {message}")
24-
cached_query = await self.sql_connector.fetch_queries_from_cache(
25-
message, injected_parameters=injected_parameters
24+
cached_query = (
25+
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
26+
message, injected_parameters=injected_parameters
27+
)
2628
)
2729

2830
# If any question has pre-run results, set the flag
29-
if cached_query.get("contains_pre_run_results", False):
30-
cached_results["contains_pre_run_results"] = True
31+
if cached_query.get(
32+
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
33+
False,
34+
):
35+
cached_results[
36+
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
37+
] = True
3138

3239
# Add the cached results for this question
33-
if cached_query.get("cached_questions_and_schemas"):
34-
cached_results["cached_questions_and_schemas"].extend(
35-
cached_query["cached_questions_and_schemas"]
40+
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
41+
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
42+
cached_query["cached_sql_queries_with_schemas_from_cache"]
3643
)
3744

3845
logging.info(f"Final cached results: {cached_results}")

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,25 @@ def create(
8383
)
8484
elif engine == DatabaseEngine.TSQL:
8585
from text_2_sql_core.data_dictionary.tsql_data_dictionary_creator import (
86-
TSQLDataDictionaryCreator,
86+
TsqlDataDictionaryCreator,
8787
)
8888

89-
data_dictionary_creator = TSQLDataDictionaryCreator(
89+
data_dictionary_creator = TsqlDataDictionaryCreator(
9090
**kwargs,
9191
)
92+
elif engine == DatabaseEngine.POSTGRESQL:
93+
from text_2_sql_core.data_dictionary.postgresql_data_dictionary_creator import (
94+
PostgresqlDataDictionaryCreator,
95+
)
96+
97+
data_dictionary_creator = PostgresqlDataDictionaryCreator(
98+
**kwargs,
99+
)
100+
else:
101+
rich_print("Text2SQL Data Dictionary Creator Failed ❌")
102+
rich_print(f"Database Engine {engine.value} is not supported.")
103+
104+
raise typer.Exit(code=1)
92105
except ImportError:
93106
detailed_error = f"""Failed to import {
94107
engine.value} Data Dictionary Creator. Check you have installed the optional dependencies for this database engine."""
@@ -100,7 +113,6 @@ def create(
100113
try:
101114
asyncio.run(data_dictionary_creator.create_data_dictionary())
102115
except Exception as e:
103-
raise e
104116
logging.error(e)
105117
rich_print("Text2SQL Data Dictionary Creator Failed ❌")
106118

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/tsql_data_dictionary_creator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import asyncio
88
import os
99
from text_2_sql_core.utils.database import DatabaseEngine
10-
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector
10+
from text_2_sql_core.connectors.tsql_sql import TsqlSqlConnector
1111

1212

13-
class TSQLDataDictionaryCreator(DataDictionaryCreator):
13+
class TsqlDataDictionaryCreator(DataDictionaryCreator):
1414
def __init__(self, **kwargs):
1515
"""A method to initialize the DataDictionaryCreator class.
1616
@@ -25,7 +25,7 @@ def __init__(self, **kwargs):
2525

2626
self.database_engine = DatabaseEngine.TSQL
2727

28-
self.sql_connector = TSQLSqlConnector()
28+
self.sql_connector = TsqlSqlConnector()
2929

3030
"""A class to extract data dictionary information from a SQL Server database."""
3131

@@ -115,5 +115,5 @@ def extract_entity_relationships_sql_query(self) -> str:
115115

116116

117117
if __name__ == "__main__":
118-
data_dictionary_creator = TSQLDataDictionaryCreator()
118+
data_dictionary_creator = TsqlDataDictionaryCreator()
119119
asyncio.run(data_dictionary_creator.create_data_dictionary())

0 commit comments

Comments
 (0)