Skip to content

Commit c66c488

Browse files
committed
Use tsql instead
1 parent da303a2 commit c66c488

File tree

6 files changed

+30
-8
lines changed

6 files changed

+30
-8
lines changed

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class DatabaseEngine(StrEnum):
3232
"""An enumeration to represent a database engine."""
3333

3434
SNOWFLAKE = "SNOWFLAKE"
35-
SQL_SERVER = "SQL_SERVER"
35+
TSQL = "TSQL"
3636
DATABRICKS = "DATABRICKS"
3737

3838

@@ -69,7 +69,7 @@ def excluded_fields_for_database_engine(self):
6969
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
7070
if self.database_engine == DatabaseEngine.SNOWFLAKE:
7171
engine_specific_fields = ["Warehouse", "Database"]
72-
elif self.database_engine == DatabaseEngine.SQL_SERVER:
72+
elif self.database_engine == DatabaseEngine.TSQL:
7373
engine_specific_fields = ["Database"]
7474
elif self.database_engine == DatabaseEngine.DATABRICKS:
7575
engine_specific_fields = ["Catalog"]

text_2_sql/autogen/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ python-dotenv
99
openai
1010
jinja2
1111
pyyaml
12+
sqlglot[rs]

text_2_sql/autogen/utils/sql.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import logging
44
import os
55
import aioodbc
6-
from typing import Annotated
6+
from typing import Annotated, Union
77
from utils.ai_search import run_ai_search_query
88
import json
99
import asyncio
10+
import sqlglot
1011

1112
USE_QUERY_CACHE = os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"
1213

@@ -66,7 +67,12 @@ async def get_entity_schemas(
6667
return json.dumps(schemas, default=str)
6768

6869

69-
async def query_execution(sql_query: str) -> list[dict]:
70+
async def query_execution(
71+
sql_query: Annotated[
72+
str,
73+
"The SQL query to run against the database.",
74+
]
75+
) -> list[dict]:
7076
"""Run the SQL query against the database.
7177
7278
Args:
@@ -91,6 +97,21 @@ async def query_execution(sql_query: str) -> list[dict]:
9197
return results
9298

9399

100+
async def validate_sql_query(
101+
sql_query: Annotated[
102+
str,
103+
"The SQL query to run against the database.",
104+
]
105+
) -> Union[bool | list[dict]]:
106+
"""Validate the SQL query."""
107+
try:
108+
sqlglot.transpile("SELECT foo FROM (SELECT baz FROM t")
109+
except sqlglot.errors.ParseError as e:
110+
return e.errors
111+
else:
112+
return True
113+
114+
94115
async def fetch_queries_from_cache(question: str) -> str:
95116
"""Fetch the queries from the cache based on the question.
96117

text_2_sql/data_dictionary/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,6 @@ The following Databases have pre-built scripts for them:
101101

102102
- **Databricks:** `databricks_data_dictionary_creator.py`
103103
- **Snowflake:** `snowflake_data_dictionary_creator.py`
104-
- **SQL Server:** `sql_server_data_dictionary_creator.py`
104+
- **SQL Server:** `tsql_data_dictionary_creator.py`
105105

106106
If there is no pre-built script for your database engine, take one of the above as a starting point and adjust it.

text_2_sql/data_dictionary/data_dictionary_creator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class DatabaseEngine(StrEnum):
2424
"""An enumeration to represent a database engine."""
2525

2626
SNOWFLAKE = "SNOWFLAKE"
27-
SQL_SERVER = "SQL_SERVER"
27+
TSQL = "TSQL"
2828
DATABRICKS = "DATABRICKS"
2929

3030

@@ -657,7 +657,7 @@ def excluded_fields_for_database_engine(self):
657657
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
658658
if self.database_engine == DatabaseEngine.SNOWFLAKE:
659659
engine_specific_fields = ["Warehouse", "Database"]
660-
elif self.database_engine == DatabaseEngine.SQL_SERVER:
660+
elif self.database_engine == DatabaseEngine.TSQL:
661661
engine_specific_fields = ["Database"]
662662
elif self.database_engine == DatabaseEngine.DATABRICKS:
663663
engine_specific_fields = ["Catalog"]

text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py renamed to text_2_sql/data_dictionary/tsql_data_dictionary_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
super().__init__(entities, excluded_entities, excluded_schemas, single_file)
2727
self.database = os.environ["Text2Sql__DatabaseName"]
2828

29-
self.database_engine = DatabaseEngine.SQL_SERVER
29+
self.database_engine = DatabaseEngine.TSQL
3030

3131
"""A class to extract data dictionary information from a SQL Server database."""
3232

0 commit comments

Comments
 (0)