|
27 | 27 | "is_positive_int",
|
28 | 28 | "is_string",
|
29 | 29 | "is_valid_name",
|
| 30 | + "is_valid_sql_name", |
30 | 31 | "simple_join_keys_predicate",
|
31 | 32 | "unique_properties_predicate",
|
32 | 33 | ]
|
33 | 34 |
|
34 | 35 |
|
| 36 | +import builtins |
| 37 | +import keyword |
| 38 | +import re |
35 | 39 | from abc import ABC, abstractmethod
|
36 | 40 |
|
37 | 41 | import numpy as np
|
@@ -97,11 +101,189 @@ class ValidName(PyDoughPredicate):
|
97 | 101 | as the name of a PyDough graph/collection/property.
|
98 | 102 | """
|
99 | 103 |
|
| 104 | + def __init__(self): |
| 105 | + self.error_messages: dict[str, str] = { |
| 106 | + "identifier": "must be a string that is a valid Python identifier", |
| 107 | + "python_keyword": "must be a string that is not a Python reserved word or built-in name", |
| 108 | + "pydough_keyword": "must be a string that is not a PyDough reserved word", |
| 109 | + "sql_keyword": "must be a string that is not a SQL reserved word", |
| 110 | + } |
| 111 | + |
| 112 | + def _error_code(self, obj: object) -> str | None: |
| 113 | + """Return an error code if invalid, or None if valid.""" |
| 114 | + ret_val: str | None = None |
| 115 | + # Check that obj is a string |
| 116 | + if isinstance(obj, str): |
| 117 | + # Check that obj is a valid Python identifier |
| 118 | + if not obj.isidentifier(): |
| 119 | + ret_val = "identifier" |
| 120 | + # Check that obj is not a Python reserved word or built-in name |
| 121 | + elif keyword.iskeyword(obj) or hasattr(builtins, obj): |
| 122 | + ret_val = "python_keyword" |
| 123 | + # Check that obj is not a PyDough reserved word |
| 124 | + elif self._is_pydough_keyword(obj): |
| 125 | + ret_val = "pydough_keyword" |
| 126 | + else: |
| 127 | + ret_val = "identifier" |
| 128 | + |
| 129 | + return ret_val |
| 130 | + |
| 131 | + def _is_pydough_keyword(self, name: str) -> bool: |
| 132 | + """ |
| 133 | + helper: Verifies if name is a PyDough reserved word. |
| 134 | + Extend with new PyDough reserved words if required. |
| 135 | + """ |
| 136 | + # Dictionary of all registered operators pre-built from the PyDough source |
| 137 | + from pydough.pydough_operators import builtin_registered_operators |
| 138 | + |
| 139 | + # Set of collection operators |
| 140 | + PYDOUGH_RESERVED: set[str] = { |
| 141 | + "CALCULATE", |
| 142 | + "WHERE", |
| 143 | + "ORDER_BY", |
| 144 | + "TOP_K", |
| 145 | + "PARTITION", |
| 146 | + "SINGULAR", |
| 147 | + "BEST", |
| 148 | + "CROSS", |
| 149 | + } |
| 150 | + return (name in PYDOUGH_RESERVED) or (name in builtin_registered_operators()) |
| 151 | + |
100 | 152 | def accept(self, obj: object) -> bool:
|
101 |
| - return isinstance(obj, str) and obj.isidentifier() |
| 153 | + return self._error_code(obj) is None |
102 | 154 |
|
103 | 155 | def error_message(self, error_name: str) -> str:
|
104 |
| - return f"{error_name} must be a string that is a Python identifier" |
| 156 | + # Generic fallback (since we don't have the object here) |
| 157 | + return f"{error_name} must be a valid identifier and not a reserved word" |
| 158 | + |
| 159 | + def verify(self, obj: object, error_name: str) -> None: |
| 160 | + code: str | None = self._error_code(obj) |
| 161 | + if code is not None: |
| 162 | + raise PyDoughMetadataException(f"{error_name} {self.error_messages[code]}") |
| 163 | + |
| 164 | + |
| 165 | +class ValidSQLName(PyDoughPredicate): |
| 166 | + """Predicate class to check that an object is a string that can be used |
| 167 | + as the name for a SQL table path/column name. |
| 168 | + """ |
| 169 | + |
| 170 | + # Regex for unquoted SQL identifiers |
| 171 | + _UNQUOTED_SQL_IDENTIFIER = re.compile( |
| 172 | + r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$" |
| 173 | + ) |
| 174 | + |
| 175 | + def __init__(self): |
| 176 | + self.error_messages: dict[str, str] = { |
| 177 | + "identifier": "must have a SQL name that is a valid SQL identifier", |
| 178 | + "sql_keyword": "must have a SQL name that is not a reserved word", |
| 179 | + } |
| 180 | + |
| 181 | + def _error_code(self, obj: object) -> str | None: |
| 182 | + """Return an error code if invalid, or None if valid.""" |
| 183 | + ret_val: str | None = None |
| 184 | + # Check that obj is a string |
| 185 | + if isinstance(obj, str): |
| 186 | + # Check that obj is a valid SQL identifier |
| 187 | + if not self.is_valid_sql_identifier(obj): |
| 188 | + ret_val = "identifier" |
| 189 | + # Check that obj is not a SQL reserved word |
| 190 | + elif self._is_sql_keyword(obj): |
| 191 | + ret_val = "sql_keyword" |
| 192 | + else: |
| 193 | + ret_val = "identifier" |
| 194 | + |
| 195 | + return ret_val |
| 196 | + |
| 197 | + def is_valid_sql_identifier(self, name: str) -> bool: |
| 198 | + """ |
| 199 | + Check if a string is a valid SQL identifier. |
| 200 | +
|
| 201 | + - Unquoted: starts with letter/underscore, then letters, digits, |
| 202 | + underscores. |
| 203 | + - Double-quoted: allows any chars, but " "" " is the only valid way to |
| 204 | + include a double-quote char. |
| 205 | + - Backtick-quoted: allows any chars, but `` `` `` is the only valid |
| 206 | + way to include a backtick char. |
| 207 | + """ |
| 208 | + if not name: |
| 209 | + return False |
| 210 | + |
| 211 | + # Case 1: unquoted |
| 212 | + if self._UNQUOTED_SQL_IDENTIFIER.match(name): |
| 213 | + return True |
| 214 | + |
| 215 | + # Case 2: double quoted |
| 216 | + if name.startswith('"') and name.endswith('"'): |
| 217 | + inner = name[1:-1] |
| 218 | + # Any " must be escaped as "" |
| 219 | + return '"' not in inner.replace('""', "") |
| 220 | + |
| 221 | + # Case 3: backtick quoted |
| 222 | + if name.startswith("`") and name.endswith("`"): |
| 223 | + inner = name[1:-1] |
| 224 | + # Any ` must be escaped as `` |
| 225 | + return "`" not in inner.replace("``", "") |
| 226 | + |
| 227 | + return False |
| 228 | + |
| 229 | + # fmt: off |
| 230 | + SQL_RESERVED_KEYWORDS: set[str] = { |
| 231 | + # Query & DML |
| 232 | + "select", "from", "where", "group", "having", "distinct", "as", |
| 233 | + "join", "inner", "union", "intersect", "except", |
| 234 | + |
| 235 | + # DDL & schema |
| 236 | + "create", "alter", "drop", "table", "view", "index", "sequence", |
| 237 | + "trigger", "schema", "database", "column", "constraint", |
| 238 | + |
| 239 | + # DML |
| 240 | + "insert", "update", "delete", "into", "values", "set", |
| 241 | + |
| 242 | + # Control flow & logical |
| 243 | + "and", "or", "not", "in", "is", "like", "between", "case", "when", |
| 244 | + "then", "else", "end", "exists", |
| 245 | + |
| 246 | + # Transaction & session |
| 247 | + "begin", "commit", "rollback", "savepoint", "transaction", |
| 248 | + "lock", "grant", "revoke", |
| 249 | + |
| 250 | + # Data types |
| 251 | + "int", "integer", "bigint", "smallint", "decimal", "numeric", |
| 252 | + "float", "real", "double", "char", "varchar", "text", |
| 253 | + "timestamp", "boolean", "null", |
| 254 | + |
| 255 | + # Functions |
| 256 | + "cast", |
| 257 | + } |
| 258 | + """ |
| 259 | + Set of SQL reserved keywords that may cause conflicts when used as table or |
| 260 | + column names. This list was compiled from commonly reserved terms across |
| 261 | + multiple SQL dialects (e.g., PostgreSQL, SQLite, MySQL), with emphasis on |
| 262 | + keywords that are likely to appear in generated SQL statements. |
| 263 | + If any of these are used as identifiers, they must be properly escaped to |
| 264 | + avoid syntax errors. |
| 265 | + """ |
| 266 | + # fmt: on |
| 267 | + |
| 268 | + def _is_sql_keyword(self, name: str) -> bool: |
| 269 | + """ |
| 270 | + helper: Verifies if name is a SQL reserved word. |
| 271 | + Uses SQL_RESERVED_KEYWORDS set. |
| 272 | + Extend with new SQL reserved words if required. |
| 273 | + """ |
| 274 | + return name.lower() in self.SQL_RESERVED_KEYWORDS |
| 275 | + |
| 276 | + def accept(self, obj: object) -> bool: |
| 277 | + return self._error_code(obj) is None |
| 278 | + |
| 279 | + def error_message(self, error_name: str) -> str: |
| 280 | + # Generic fallback (since we don't have the object here) |
| 281 | + return f"{error_name} must be a valid SQL identifier and not a reserved word" |
| 282 | + |
| 283 | + def verify(self, obj: object, error_name: str) -> None: |
| 284 | + code: str | None = self._error_code(obj) |
| 285 | + if code is not None: |
| 286 | + raise PyDoughMetadataException(f"{error_name} {self.error_messages[code]}") |
105 | 287 |
|
106 | 288 |
|
107 | 289 | class NoExtraKeys(PyDoughPredicate):
|
@@ -304,6 +486,7 @@ def error_message(self, error_name: str) -> str:
|
304 | 486 | ###############################################################################
|
305 | 487 |
|
306 | 488 | is_valid_name: PyDoughPredicate = ValidName()
|
| 489 | +is_valid_sql_name: PyDoughPredicate = ValidSQLName() |
307 | 490 | is_integer = HasType(int, "integer")
|
308 | 491 | is_string = HasType(str, "string")
|
309 | 492 | is_bool = HasType(bool, "boolean")
|
|
0 commit comments