Skip to content

Commit 83a068b

Browse files
committed
Add sanitizer
1 parent 0ab47b4 commit 83a068b

File tree

6 files changed

+103
-9
lines changed

6 files changed

+103
-9
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ def invalid_identifiers(self) -> list[str]:
5656
"SHOW DATABASES",
5757
]
5858

59+
def sanitize_identifier(self, identifier: str) -> str:
    """Quote an identifier for Databricks SQL using backticks.

    A backtick inside the identifier is doubled so the value cannot close
    the quoted name early.

    Args:
    ----
        identifier (str): The raw, unquoted identifier to sanitize.

    Returns:
    -------
        str: The identifier wrapped in backticks, with any embedded
            backticks doubled.
    """
    escaped = identifier.replace("`", "``")
    return f"`{escaped}`"
71+
5972
async def query_execution(
6073
self,
6174
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgres_sql.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ def invalid_identifiers(self) -> list[str]:
4848
"PGP_PUB_DECRYPT()", # (from pgcrypto extension) Asymmetric decryption function
4949
]
5050

51+
def sanitize_identifier(self, identifier: str) -> str:
    """Quote an identifier for PostgreSQL using double quotes.

    A double quote inside the identifier is doubled so the value cannot
    close the quoted name early.

    Args:
    ----
        identifier (str): The raw, unquoted identifier to sanitize.

    Returns:
    -------
        str: The identifier wrapped in double quotes, with any embedded
            double quotes doubled.
    """
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'
63+
5164
async def query_execution(
5265
self,
5366
sql_query: Annotated[str, "The SQL query to run against the database."],

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@ def invalid_identifiers(self) -> list[str]:
7272
"QUERY_MEMORY_USAGE",
7373
]
7474

75+
def sanitize_identifier(self, identifier: str) -> str:
    """Quote an identifier for Snowflake using double quotes.

    A double quote inside the identifier is doubled so the value cannot
    close the quoted name early.

    Args:
    ----
        identifier (str): The raw, unquoted identifier to sanitize.

    Returns:
    -------
        str: The identifier wrapped in double quotes, with any embedded
            double quotes doubled.
    """
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'
87+
7588
async def query_execution(
7689
self,
7790
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from jinja2 import Template
1212
import json
1313
from text_2_sql_core.utils.database import DatabaseEngineSpecificFields
14+
import re
1415

1516

1617
class SqlConnector(ABC):
@@ -40,19 +41,16 @@ def __init__(self):
4041
@abstractmethod
4142
def engine_specific_rules(self) -> str:
4243
"""Get the engine specific rules."""
43-
pass
4444

4545
@property
4646
@abstractmethod
4747
def invalid_identifiers(self) -> list[str]:
4848
"""Get the invalid identifiers upon which a sql query is rejected."""
49-
pass
5049

5150
@property
5251
@abstractmethod
5352
def engine_specific_fields(self) -> list[str]:
5453
"""Get the engine specific fields."""
55-
pass
5654

5755
@property
5856
def excluded_engine_specific_fields(self):
@@ -85,6 +83,19 @@ async def query_execution(
8583
list[dict]: The results of the SQL query.
8684
"""
8785

86+
@abstractmethod
87+
def sanitize_identifier(self, identifier: str) -> str:
88+
"""Sanitize the identifier to ensure it is valid.

Concrete connectors implement this by wrapping the identifier in the
quoting characters of their SQL dialect (backticks, double quotes or
square brackets). ``clean_query`` calls it on tokens taken from the
query text.

Args:
----
identifier (str): The identifier to sanitize.

Returns:
-------
str: The sanitized identifier.
"""
98+
8899
async def get_column_values(
89100
self,
90101
text: Annotated[
@@ -204,6 +215,26 @@ async def query_execution_with_limit(
204215
default=str,
205216
)
206217

218+
def clean_query(self, sql_query: str) -> str:
    """Flatten the SQL query to one line and quote hyphenated identifiers.

    Surrounding whitespace is stripped and newlines become spaces. Bare
    tokens that contain a hyphen (for example ``my-table``) are passed to
    ``sanitize_identifier`` so the engine-specific quoting is applied.
    Everything else is left untouched: keywords, function names, plain
    identifiers, string literals and identifiers that are already quoted
    with double quotes, backticks or square brackets.

    Note that a hyphen written directly between two bare words, with no
    surrounding spaces, is read as part of a single identifier.
    Subtraction must therefore be written with spaces (``a - b``).

    Args:
    ----
        sql_query (str): The SQL query to clean.

    Returns:
    -------
        str: The cleaned SQL query.
    """
    single_line_query = sql_query.strip().replace("\n", " ")

    # Alternatives are tried left to right, so a quoted region is consumed
    # whole before the bare-word alternative can look inside it.
    token_pattern = (
        r"'(?:[^']|'')*'"  # single-quoted string literal
        r'|"(?:[^"]|"")*"'  # double-quoted identifier
        r"|`(?:[^`]|``)*`"  # backtick-quoted identifier
        r"|\[(?:[^\]]|\]\])*\]"  # bracket-quoted identifier
        # Bare word; hyphens are only allowed between word characters so a
        # trailing '-' or a '--' comment marker is never swallowed.
        r"|(?<!\w)[a-zA-Z_][a-zA-Z0-9_]*(?:-[a-zA-Z0-9_]+)*"
    )

    def quote_if_needed(match: re.Match) -> str:
        """Return the token unchanged unless it is a bare hyphenated word."""
        token = match.group(0)
        # Bare words always start with a letter or underscore, so the first
        # character tells quoted regions apart from bare words.
        if token[0] in "'\"`[" or "-" not in token:
            return token
        return self.sanitize_identifier(token)

    return re.sub(token_pattern, quote_if_needed, single_line_query)
237+
207238
async def query_validation(
208239
self,
209240
sql_query: Annotated[
@@ -213,7 +244,7 @@ async def query_validation(
213244
) -> Union[bool | list[dict]]:
214245
"""Validate the SQL query."""
215246
try:
216-
cleaned_query = sql_query.strip().replace("\n", " ")
247+
cleaned_query = self.clean_query(sql_query)
217248
logging.info("Validating SQL Query: %s", cleaned_query)
218249
parsed_queries = sqlglot.parse(
219250
cleaned_query,
@@ -249,14 +280,12 @@ def handle_node(node):
249280
detected_invalid_identifiers.append(identifier)
250281

251282
if len(detected_invalid_identifiers) > 0:
252-
logging.error(
253-
"SQL Query contains invalid identifiers: %s",
254-
detected_invalid_identifiers,
255-
)
256-
return (
283+
error_message = (
257284
"SQL Query contains invalid identifiers: %s"
258285
% detected_invalid_identifiers
259286
)
287+
logging.error(error_message)
288+
return False, None, error_message
260289

261290
except sqlglot.errors.ParseError as e:
262291
logging.error("SQL Query is invalid: %s", e.errors)

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sqlite_sql.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ def engine_specific_fields(self) -> list[str]:
4444
"""Get the engine specific fields."""
4545
return [] # SQLite doesn't use warehouses, catalogs, or separate databases
4646

47+
def sanitize_identifier(self, identifier: str) -> str:
    """Quote an identifier for SQLite using double quotes.

    A double quote inside the identifier is doubled so the value cannot
    close the quoted name early.

    Args:
    ----
        identifier (str): The raw, unquoted identifier to sanitize.

    Returns:
    -------
        str: The identifier wrapped in double quotes, with any embedded
            double quotes doubled.
    """
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'
59+
4760
async def query_execution(
4861
self,
4962
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,19 @@ def invalid_identifiers(self) -> list[str]:
6565
"VERSION",
6666
]
6767

68+
def sanitize_identifier(self, identifier: str) -> str:
    """Quote an identifier for T-SQL using square brackets.

    A closing bracket inside the identifier is doubled so the value cannot
    close the bracketed name early. An opening bracket needs no escaping.

    Args:
    ----
        identifier (str): The raw, unquoted identifier to sanitize.

    Returns:
    -------
        str: The identifier wrapped in square brackets, with any embedded
            ``]`` doubled.
    """
    escaped = identifier.replace("]", "]]")
    return f"[{escaped}]"
80+
6881
async def query_execution(
6982
self,
7083
sql_query: Annotated[

0 commit comments

Comments
 (0)