Skip to content

Commit e6d62a3

Browse files
committed
Add postgres identifiers
1 parent f9d335c commit e6d62a3

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@ def engine_specific_fields(self) -> list[str]:
2121
"""Get the engine specific fields."""
2222
return [DatabaseEngineSpecificFields.DATABASE]
2323

24+
@property
25+
def invalid_identifiers(self) -> list[str]:
26+
"""Get the invalid identifiers upon which a sql query is rejected."""
27+
28+
return [
29+
"CURRENT_USER", # Returns the name of the current user
30+
"SESSION_USER", # Returns the name of the user that initiated the session
31+
"USER", # Returns the name of the current user
32+
"CURRENT_ROLE", # Returns the current role
33+
"CURRENT_DATABASE", # Returns the name of the current database
34+
"CURRENT_SCHEMA()", # Returns the name of the current schema
35+
"CURRENT_SETTING()", # Returns the value of a specified configuration parameter
36+
"PG_CURRENT_XACT_ID()", # Returns the current transaction ID
37+
# (if the extension is enabled) Provides a view of query statistics
38+
"PG_STAT_STATEMENTS()",
39+
"PG_SLEEP()", # Delays execution by the specified number of seconds
40+
"CLIENT_ADDR()", # Returns the IP address of the client (from pg_stat_activity)
41+
"CLIENT_HOSTNAME()", # Returns the hostname of the client (from pg_stat_activity)
42+
"PGP_SYM_DECRYPT()", # (from pgcrypto extension) Symmetric decryption function
43+
"PGP_PUB_DECRYPT()", # (from pgcrypto extension) Asymmetric decryption function
44+
]
45+
2446
async def query_execution(
2547
self,
2648
sql_query: Annotated[str, "The SQL query to run against the database."],

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def handle_node(node):
172172

173173
for token in expressions + identifiers:
174174
if isinstance(token, Parameter):
175-
identifier = token.this.this
175+
identifier = str(token.this.this).upper()
176176
else:
177177
identifier = str(token).strip("()").upper()
178178

0 commit comments

Comments
 (0)