Skip to content

Commit 1700a68

Browse files
committed
Merge branch 'main' into feature/postgres-support
2 parents 825d2c8 + 06cd094 commit 1700a68

File tree

4 files changed

+157
-2
lines changed

4 files changed

+157
-2
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,35 @@ def engine_specific_fields(self) -> list[str]:
2222
"""Get the engine specific fields."""
2323
return [DatabaseEngineSpecificFields.CATALOG]
2424

25+
@property
26+
def invalid_identifiers(self) -> list[str]:
27+
"""Get the invalid identifiers upon which a sql query is rejected."""
28+
return [
29+
# Session and system variables
30+
"CURRENT_CATALOG",
31+
"CURRENT_DATABASE",
32+
"CURRENT_USER",
33+
"SESSION_USER",
34+
"CURRENT_ROLE",
35+
"CURRENT_QUERY",
36+
"CURRENT_WAREHOUSE",
37+
"SESSION_ID",
38+
# System metadata functions
39+
"DATABASE",
40+
"USER",
41+
# Potentially unsafe built-in functions
42+
"CURRENT_USER",
43+
"SESSION_USER",
44+
"SYSTEM",
45+
"SHOW",
46+
"DESCRIBE",
47+
"EXPLAIN",
48+
"SET",
49+
"SHOW TABLES",
50+
"SHOW COLUMNS",
51+
"SHOW DATABASES",
52+
]
53+
2554
async def query_execution(
2655
self,
2756
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,48 @@ def engine_specific_fields(self) -> list[str]:
2525
DatabaseEngineSpecificFields.DATABASE,
2626
]
2727

28+
@property
29+
def invalid_identifiers(self) -> list[str]:
30+
"""Get the invalid identifiers upon which a sql query is rejected."""
31+
return [
32+
"CURRENT_CLIENT",
33+
"CURRENT_IP_ADDRESS",
34+
"CURRENT_REGION",
35+
"CURRENT_VERSION",
36+
"ALL_USER_NAMES",
37+
"CURRENT_ACCOUNT",
38+
"CURRENT_ACCOUNT_NAME",
39+
"CURRENT_ORGANIZATION_NAME",
40+
"CURRENT_ROLE",
41+
"CURRENT_AVAILABLE_ROLES",
42+
"CURRENT_SECONDARY_ROLES",
43+
"CURRENT_SESSION",
44+
"CURRENT_STATEMENT",
45+
"CURRENT_TRANSACTION",
46+
"CURRENT_USER",
47+
"GETVARIABLE",
48+
"LAST_QUERY_ID",
49+
"LAST_TRANSACTION",
50+
"CURRENT_DATABASE",
51+
"CURRENT_ROLE_TYPE",
52+
"CURRENT_SCHEMA",
53+
"CURRENT_SCHEMAS",
54+
"CURRENT_WAREHOUSE",
55+
"INVOKER_ROLE",
56+
"INVOKER_SHARE",
57+
"IS_APPLICATION_ROLE_IN_SESSION",
58+
"IS_DATABASE_ROLE_IN_SESSION",
59+
"IS_GRANTED_TO_INVOKER_ROLE",
60+
"IS_INSTANCE_ROLE_IN_SESSION",
61+
"IS_ROLE_IN_SESSION",
62+
"POLICY_CONTEXT",
63+
"CURRENT_SESSION_USER",
64+
"SESSION_ID",
65+
"QUERY_START_TIME",
66+
"QUERY_ELAPSED_TIME",
67+
"QUERY_MEMORY_USAGE",
68+
]
69+
2870
async def query_execution(
2971
self,
3072
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from text_2_sql_core.connectors.factory import ConnectorFactory
77
import asyncio
88
import sqlglot
9+
from sqlglot.expressions import Parameter, Select, Identifier
910
from abc import ABC, abstractmethod
1011
from jinja2 import Template
1112
import json
@@ -30,6 +31,12 @@ def __init__(self):
3031

3132
self.database_engine = None
3233

34+
@property
35+
@abstractmethod
36+
def invalid_identifiers(self) -> list[str]:
37+
"""Get the invalid identifiers upon which a sql query is rejected."""
38+
pass
39+
3340
@abstractmethod
3441
@property
3542
def engine_specific_fields(self) -> list[str]:
@@ -140,11 +147,49 @@ async def query_validation(
140147
"""Validate the SQL query."""
141148
try:
142149
logging.info("Validating SQL Query: %s", sql_query)
143-
sqlglot.transpile(
150+
parsed_queries = sqlglot.parse(
144151
sql_query,
145152
read=self.database_engine.value.lower(),
146-
error_level=sqlglot.ErrorLevel.RAISE,
147153
)
154+
155+
expressions = []
156+
identifiers = []
157+
158+
def handle_node(node):
159+
if isinstance(node, Select):
160+
# Extract expressions
161+
for expr in node.expressions:
162+
expressions.append(expr)
163+
elif isinstance(node, Identifier):
164+
# Extract identifiers
165+
identifiers.append(node.this)
166+
167+
detected_invalid_identifiers = []
168+
169+
for parsed_query in parsed_queries:
170+
for node in parsed_query.walk():
171+
handle_node(node)
172+
173+
for token in expressions + identifiers:
174+
if isinstance(token, Parameter):
175+
identifier = token.this.this
176+
else:
177+
identifier = str(token).strip("()").upper()
178+
179+
if identifier in self.invalid_identifiers:
180+
logging.warning("Detected invalid identifier: %s", identifier)
181+
detected_invalid_identifiers.append(identifier)
182+
183+
if len(detected_invalid_identifiers) > 0:
184+
logging.error(
185+
"SQL Query contains invalid identifiers: %s",
186+
detected_invalid_identifiers,
187+
)
188+
return (
189+
"SQL Query contains invalid identifiers: %s"
190+
% detected_invalid_identifiers
191+
)
192+
148193
except sqlglot.errors.ParseError as e:
149194
logging.error("SQL Query is invalid: %s", e.errors)
150195
return e.errors

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@ def engine_specific_fields(self) -> list[str]:
2121
"""Get the engine specific fields."""
2222
return [DatabaseEngineSpecificFields.DATABASE]
2323

24+
@property
25+
def invalid_identifiers(self) -> list[str]:
26+
"""Get the invalid identifiers upon which a sql query is rejected."""
27+
return [
28+
"CONNECTIONS",
29+
"CPU_BUSY",
30+
"CURSOR_ROWS",
31+
"DATEFIRST",
32+
"DBTS",
33+
"ERROR",
34+
"FETCH_STATUS",
35+
"IDENTITY",
36+
"IDLE",
37+
"IO_BUSY",
38+
"LANGID",
39+
"LANGUAGE",
40+
"LOCK_TIMEOUT",
41+
"MAX_CONNECTIONS",
42+
"MAX_PRECISION",
43+
"NESTLEVEL",
44+
"OPTIONS",
45+
"PACK_RECEIVED",
46+
"PACK_SENT",
47+
"PACKET_ERRORS",
48+
"PROCID",
49+
"REMSERVER",
50+
"ROWCOUNT",
51+
"SERVERNAME",
52+
"SERVICENAME",
53+
"SPID",
54+
"TEXTSIZE",
55+
"TIMETICKS",
56+
"TOTAL_ERRORS",
57+
"TOTAL_READ",
58+
"TOTAL_WRITE",
59+
"TRANCOUNT",
60+
"VERSION",
61+
]
62+
2463
async def query_execution(
2564
self,
2665
sql_query: Annotated[

0 commit comments

Comments
 (0)