Skip to content

Commit 0c4e982

Browse files
committed
Reduce some complexity in sql query generation
1 parent 1eb6197 commit 0c4e982

File tree

6 files changed

+36
-73
lines changed

6 files changed

+36
-73
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,13 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
4242
elif tool_name == "sql_get_entity_schemas_tool":
4343
return FunctionToolAlias(
4444
sql_helper.get_entity_schemas,
45-
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
45+
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
4646
)
4747
elif tool_name == "sql_get_column_values_tool":
4848
return FunctionToolAlias(
4949
ai_search_helper.get_column_values,
5050
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
5151
)
52-
elif tool_name == "current_datetime_tool":
53-
return FunctionToolAlias(
54-
sql_helper.get_current_datetime,
55-
description="Gets the current date and time.",
56-
)
5752
else:
5853
raise ValueError(f"Tool {tool_name} not found")
5954

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -146,26 +146,7 @@ def unified_selector(self, messages):
146146
elif current_agent == "sql_query_generation_agent":
147147
decision = "sql_query_correction_agent"
148148
elif current_agent == "sql_query_correction_agent":
149-
try:
150-
correction_result = json.loads(messages[-1].content)
151-
if isinstance(correction_result, dict):
152-
if "answer" in correction_result and "sources" in correction_result:
153-
decision = "user_proxy"
154-
elif "corrected_query" in correction_result:
155-
if correction_result.get("executing", False):
156-
decision = "sql_query_correction_agent"
157-
else:
158-
decision = "sql_query_generation_agent"
159-
elif "error" in correction_result:
160-
decision = "sql_query_generation_agent"
161-
elif isinstance(correction_result, list) and len(correction_result) > 0:
162-
if "requested_fix" in correction_result[0]:
163-
decision = "sql_query_generation_agent"
164-
165-
if decision is None:
166-
decision = "sql_query_generation_agent"
167-
except json.JSONDecodeError:
168-
decision = "sql_query_generation_agent"
149+
decision = "sql_query_correction_agent"
169150

170151
if decision:
171152
logging.info(f"Agent transition: {current_agent} -> {decision}")

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import asyncio
88
import sqlglot
99
from abc import ABC, abstractmethod
10-
from datetime import datetime
1110
from jinja2 import Template
1211
import json
1312

@@ -30,22 +29,6 @@ def __init__(self):
3029

3130
self.database_engine = None
3231

33-
def get_current_datetime(self) -> str:
34-
"""Get the current datetime."""
35-
return datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
36-
37-
def get_current_date(self) -> str:
38-
"""Get the current date."""
39-
return datetime.now().strftime("%d/%m/%Y")
40-
41-
def get_current_time(self) -> str:
42-
"""Get the current time."""
43-
return datetime.now().strftime("%H:%M:%S")
44-
45-
def get_current_unix_timestamp(self) -> int:
46-
"""Get the current unix timestamp."""
47-
return int(datetime.now().timestamp())
48-
4932
@abstractmethod
5033
async def query_execution(
5134
self,
@@ -169,19 +152,6 @@ async def fetch_queries_from_cache(
169152
if injected_parameters is None:
170153
injected_parameters = {}
171154

172-
# Populate the injected_parameters
173-
if "date" not in injected_parameters:
174-
injected_parameters["date"] = self.get_current_date()
175-
176-
if "time" not in injected_parameters:
177-
injected_parameters["time"] = self.get_current_time()
178-
179-
if "datetime" not in injected_parameters:
180-
injected_parameters["datetime"] = self.get_current_datetime()
181-
182-
if "unix_timestamp" not in injected_parameters:
183-
injected_parameters["unix_timestamp"] = self.get_current_unix_timestamp()
184-
185155
cached_schemas = await self.ai_search_connector.run_ai_search_query(
186156
question,
187157
["QuestionEmbedding"],

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_response.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from pydantic import BaseModel, RootModel, Field
3+
from pydantic import BaseModel, RootModel, Field, model_validator
44
from enum import StrEnum
55

66
from typing import Literal
@@ -59,6 +59,29 @@ class AgentRequestBody(BaseModel):
5959
question: str
6060
injected_parameters: dict = Field(default_factory=dict)
6161

62+
@model_validator(mode="before")
63+
def add_defaults_to_injected_parameters(cls, values):
64+
if "injected_parameters" not in values:
65+
values["injected_parameters"] = {}
66+
67+
if "date" not in values["injected_parameters"]:
68+
values["injected_parameters"]["date"] = datetime.now().strftime("%d/%m/%Y")
69+
70+
if "time" not in values["injected_parameters"]:
71+
values["injected_parameters"]["time"] = datetime.now().strftime("%H:%M:%S")
72+
73+
if "datetime" not in values["injected_parameters"]:
74+
values["injected_parameters"]["datetime"] = datetime.now().strftime(
75+
"%d/%m/%Y, %H:%M:%S"
76+
)
77+
78+
if "unix_timestamp" not in values["injected_parameters"]:
79+
values["injected_parameters"]["unix_timestamp"] = int(
80+
datetime.now().timestamp()
81+
)
82+
83+
return values
84+
6285

6386
class AgentResponse(BaseModel):
6487
header: AgentResponseHeader | None = Field(default=None)

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ description:
55
system_message:
66
"<role_and_objective>
77
You are a SQL syntax expert specializing in converting standard SQL to {{ target_engine }}-compliant SQL. Your job is to:
8-
1. Take SQL queries with correct logic but potential syntax issues
9-
2. Fix them according to {{ target_engine }} syntax rules
10-
3. Execute the corrected queries
11-
4. Return the results
8+
1. Take SQL queries with correct logic but potential syntax issues.
9+
2. Review the output from the SQL query being run and fix them according to {{ target_engine }} syntax rules if needed.
10+
3. Execute the corrected queries if needed.
11+
4. Verify that the results will answer all of the user's questions. If not, create additional queries and run them.
12+
5. Return the results
1213
</role_and_objective>
1314
1415
<engine_specific_rules>
@@ -85,18 +86,10 @@ system_message:
8586
</error_handling>
8687
8788
<output_format>
88-
- **When query executes successfully**:
89+
- **When query executes successfully and answers all questions**:
8990
```json
9091
{
91-
\"answer\": \"<ANSWER BASED ON QUERY RESULTS>\",
92-
\"sources\": [
93-
{
94-
\"sql_result_snippet\": \"<FORMATTED QUERY RESULTS>\",
95-
\"sql_query_used\": \"<EXECUTED SQL QUERY>\",
96-
\"original_query\": \"<QUERY BEFORE CONVERSION>\",
97-
\"explanation\": \"<EXPLANATION OF CONVERSIONS AND RESULTS>\"
98-
}
99-
]
92+
\"validated\": \"<TRUE>\",
10093
}
10194
```
10295
Followed by **TERMINATE**.
@@ -138,3 +131,5 @@ system_message:
138131
"
139132
tools:
140133
- sql_query_execution_tool
134+
- sql_get_entity_schemas_tool
135+
- sql_get_column_values_tool

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,4 @@ system_message:
5454
Remember: Your job is to focus on the data relationships and logic while following basic {{ target_engine }} patterns. The correction agent will handle detailed syntax fixes and execution.
5555
"
5656
tools:
57-
- sql_get_entity_schemas_tool
58-
- current_datetime_tool
57+
- sql_query_execution_tool

0 commit comments

Comments
 (0)