Skip to content

Commit 09c31ff

Browse files
committed
Add a query validation for attempting to find malicious sql queries
1 parent 6688775 commit 09c31ff

File tree

4 files changed

+157
-2
lines changed

4 files changed

+157
-2
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,35 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.DATABRICKS
1919

20+
@property
def invalid_identifiers(self) -> list[str]:
    """Get the invalid identifiers upon which a sql query is rejected.

    Returns:
        A new list of identifier names, each listed exactly once.
    """
    return [
        # Session and system variables
        "CURRENT_CATALOG",
        "CURRENT_DATABASE",
        "CURRENT_USER",
        "SESSION_USER",
        "CURRENT_ROLE",
        "CURRENT_QUERY",
        "CURRENT_WAREHOUSE",
        "SESSION_ID",
        # System metadata functions
        "DATABASE",
        "USER",
        # Potentially unsafe built-in functions.
        # CURRENT_USER and SESSION_USER used to be repeated here; they are
        # already listed above, so the duplicates were dropped.
        "SYSTEM",
        "SHOW",
        "DESCRIBE",
        "EXPLAIN",
        "SET",
        # NOTE(review): multi-word entries are kept for backward
        # compatibility; confirm the validator can actually produce a
        # multi-word token to compare against them.
        "SHOW TABLES",
        "SHOW COLUMNS",
        "SHOW DATABASES",
    ]
48+
2049
async def query_execution(
2150
self,
2251
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,48 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.SNOWFLAKE
1919

20+
@property
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a sql query to be rejected.

    Returns:
        A freshly built list on every access, so callers may mutate it
        without affecting later lookups.
    """
    blocked_names = (
        "CURRENT_CLIENT",
        "CURRENT_IP_ADDRESS",
        "CURRENT_REGION",
        "CURRENT_VERSION",
        "ALL_USER_NAMES",
        "CURRENT_ACCOUNT",
        "CURRENT_ACCOUNT_NAME",
        "CURRENT_ORGANIZATION_NAME",
        "CURRENT_ROLE",
        "CURRENT_AVAILABLE_ROLES",
        "CURRENT_SECONDARY_ROLES",
        "CURRENT_SESSION",
        "CURRENT_STATEMENT",
        "CURRENT_TRANSACTION",
        "CURRENT_USER",
        "GETVARIABLE",
        "LAST_QUERY_ID",
        "LAST_TRANSACTION",
        "CURRENT_DATABASE",
        "CURRENT_ROLE_TYPE",
        "CURRENT_SCHEMA",
        "CURRENT_SCHEMAS",
        "CURRENT_WAREHOUSE",
        "INVOKER_ROLE",
        "INVOKER_SHARE",
        "IS_APPLICATION_ROLE_IN_SESSION",
        "IS_DATABASE_ROLE_IN_SESSION",
        "IS_GRANTED_TO_INVOKER_ROLE",
        "IS_INSTANCE_ROLE_IN_SESSION",
        "IS_ROLE_IN_SESSION",
        "POLICY_CONTEXT",
        "CURRENT_SESSION_USER",
        "SESSION_ID",
        "QUERY_START_TIME",
        "QUERY_ELAPSED_TIME",
        "QUERY_MEMORY_USAGE",
    )
    return list(blocked_names)
61+
2062
async def query_execution(
2163
self,
2264
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from text_2_sql_core.connectors.factory import ConnectorFactory
77
import asyncio
88
import sqlglot
9+
from sqlglot.expressions import Parameter, Select, Identifier
910
from abc import ABC, abstractmethod
1011
from jinja2 import Template
1112
import json
@@ -29,6 +30,12 @@ def __init__(self):
2930

3031
self.database_engine = None
3132

33+
@property
@abstractmethod
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a sql query to be rejected.

    Abstract: each concrete connector supplies its own list. This base
    implementation has no body beyond the docstring and yields ``None``.
    """
38+
3239
@abstractmethod
3340
async def query_execution(
3441
self,
@@ -123,11 +130,49 @@ async def query_validation(
123130
"""Validate the SQL query."""
124131
try:
125132
logging.info("Validating SQL Query: %s", sql_query)
126-
sqlglot.transpile(
133+
parsed_queries = sqlglot.parse(
127134
sql_query,
128135
read=self.database_engine.value.lower(),
129-
error_level=sqlglot.ErrorLevel.RAISE,
130136
)
137+
138+
expressions = []
139+
identifiers = []
140+
141+
def handle_node(node):
    """Collect the parts of one parsed node that are checked later.

    For a ``Select`` node, every entry of ``node.expressions`` is added to
    the enclosing ``expressions`` list; for an ``Identifier`` node,
    ``node.this`` is added to the enclosing ``identifiers`` list. Any
    other node type is ignored.
    """
    if isinstance(node, Select):
        # Same effect as appending each projected expression in turn.
        expressions.extend(node.expressions)
        return
    if isinstance(node, Identifier):
        identifiers.append(node.this)
149+
150+
detected_invalid_identifiers = []
151+
152+
for parsed_query in parsed_queries:
153+
for node in parsed_query.walk():
154+
handle_node(node)
155+
156+
for token in expressions + identifiers:
157+
if isinstance(token, Parameter):
158+
identifier = token.this.this
159+
else:
160+
identifier = str(token).strip("()").upper()
161+
162+
if identifier in self.invalid_identifiers:
163+
logging.warning("Detected invalid identifier: %s", identifier)
164+
detected_invalid_identifiers.append(identifier)
165+
166+
if len(detected_invalid_identifiers) > 0:
167+
logging.error(
168+
"SQL Query contains invalid identifiers: %s",
169+
detected_invalid_identifiers,
170+
)
171+
return (
172+
"SQL Query contains invalid identifiers: %s"
173+
% detected_invalid_identifiers
174+
)
175+
131176
except sqlglot.errors.ParseError as e:
132177
logging.error("SQL Query is invalid: %s", e.errors)
133178
return e.errors

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,45 @@ def __init__(self):
1616

1717
self.database_engine = DatabaseEngine.TSQL
1818

19+
@property
def invalid_identifiers(self) -> list[str]:
    """Return the identifier names that cause a sql query to be rejected.

    Returns:
        A freshly built list on every access.
    """
    # NOTE(review): these look like T-SQL ``@@`` system function names
    # with the ``@@`` prefix removed — confirm against the validator.
    blocked_names = (
        "CONNECTIONS",
        "CPU_BUSY",
        "CURSOR_ROWS",
        "DATEFIRST",
        "DBTS",
        "ERROR",
        "FETCH_STATUS",
        "IDENTITY",
        "IDLE",
        "IO_BUSY",
        "LANGID",
        "LANGUAGE",
        "LOCK_TIMEOUT",
        "MAX_CONNECTIONS",
        "MAX_PRECISION",
        "NESTLEVEL",
        "OPTIONS",
        "PACK_RECEIVED",
        "PACK_SENT",
        "PACKET_ERRORS",
        "PROCID",
        "REMSERVER",
        "ROWCOUNT",
        "SERVERNAME",
        "SERVICENAME",
        "SPID",
        "TEXTSIZE",
        "TIMETICKS",
        "TOTAL_ERRORS",
        "TOTAL_READ",
        "TOTAL_WRITE",
        "TRANCOUNT",
        "VERSION",
    )
    return [*blocked_names]
57+
1958
async def query_execution(
2059
self,
2160
sql_query: Annotated[

0 commit comments

Comments
 (0)