Skip to content

Commit 59a2a6d

Browse files
committed
Update
1 parent 59b6e69 commit 59a2a6d

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

deploy_ai_search/text_2_sql_query_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def get_index_fields(self) -> list[SearchableField]:
4949
vector_search_profile_name=self.vector_search_profile_name,
5050
),
5151
ComplexField(
52-
name="Decomposition",
52+
name="SqlQueryDecomposition",
5353
collection=True,
5454
fields=[
5555
SearchableField(
56-
name="SQLQuery",
56+
name="SqlQuery",
5757
type=SearchFieldDataType.String,
5858
filterable=True,
5959
),

text_2_sql/semantic_kernel/plugins/vector_based_sql_plugin/vector_based_sql_plugin.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
143143
cached_schemas = await run_ai_search_query(
144144
question,
145145
["QuestionEmbedding"],
146-
["Question", "Query", "Schemas"],
146+
["Question", "SqlQueryDecomposition", "Schemas"],
147147
os.environ["AIService__AzureSearchOptions__Text2SqlQueryCache__Index"],
148148
os.environ[
149149
"AIService__AzureSearchOptions__Text2SqlQueryCache__SemanticConfig"
@@ -172,17 +172,27 @@ async def fetch_queries_from_cache(self, question: str) -> str:
172172
if cached_schemas[0]["@search.reranker_score"] > 2.75:
173173
logging.info("Score is greater than 3")
174174

175-
sql_query = cached_schemas[0]["Query"]
176-
schemas = cached_schemas[0]["Schemas"]
175+
sql_queries = cached_schemas[0]["SqlQueryDecomposition"]
176+
query_result_store = {}
177177

178-
logging.info("SQL Query: %s", sql_query)
178+
query_tasks = []
179179

180-
# Run the SQL query
181-
sql_result = await self.query_execution(sql_query)
182-
logging.info("SQL Query Result: %s", sql_result)
180+
for sql_query in sql_queries:
181+
logging.info("SQL Query: %s", sql_query)
183182

184-
pre_fetched_results_string = f"""[BEGIN PRE-FETCHED RESULTS FOR SQL QUERY = '{sql_query}']\n{
185-
json.dumps(sql_result, default=str)}\nSchema={json.dumps(schemas, default=str)}\n[END PRE-FETCHED RESULTS FOR SQL QUERY]\n"""
183+
# Run the SQL query
184+
query_tasks.append(self.query_execution(sql_query["SqlQuery"]))
185+
186+
sql_results = await asyncio.gather(*query_tasks)
187+
188+
for sql_query, sql_result in zip(sql_queries, sql_results):
189+
query_result_store[sql_query["SqlQuery"]] = {
190+
"result": sql_result,
191+
"schemas": sql_queries["schemas"],
192+
}
193+
194+
pre_fetched_results_string = f"""[BEGIN PRE-FETCHED RESULTS FOR CACHED SQL QUERIES]\n{
195+
json.dumps(query_result_store, default=str)}\n[END PRE-FETCHED RESULTS FOR CACHED SQL QUERIES]\n"""
186196

187197
return pre_fetched_results_string
188198

@@ -330,8 +340,10 @@ async def run_sql_query(
330340

331341
entry = {
332342
"Question": self.question,
333-
"Query": sql_query,
334-
"Schemas": cleaned_schemas,
343+
"SqlQueryDecomposition": {
344+
"SqlQuery": sql_query,
345+
"Schemas": cleaned_schemas,
346+
},
335347
}
336348
except Exception as e:
337349
logging.error("Error: %s", e)

0 commit comments

Comments
 (0)