Skip to content

Commit ea8ce44

Browse files
authored
Add Reserved Words DB test pipeline for SQLite (#428)
Adds reserved_words tests to the custom_datasets test pipeline. Improves metadata validation for complex qualified SQL identifiers that contain quotes and reserved or escaped characters.
1 parent 05944e4 commit ea8ce44

12 files changed

+2834
-27
lines changed

pydough/errors/error_utils.py

Lines changed: 109 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import keyword
3838
import re
3939
from abc import ABC, abstractmethod
40+
from enum import Enum, auto
4041

4142
import numpy as np
4243

@@ -167,28 +168,125 @@ class ValidSQLName(PyDoughPredicate):
167168
as the name for a SQL table path/column name.
168169
"""
169170

170-
# Regex for unquoted SQL identifiers
171-
_UNQUOTED_SQL_IDENTIFIER = re.compile(
172-
r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$"
173-
)
171+
# Single-part unquoted SQL identifier (no dots here).
172+
UNQUOTED_SQL_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
173+
"""
174+
Regex pattern for a single-part unquoted SQL identifier (without dots).
175+
"""
174176

175177
def __init__(self):
176178
self.error_messages: dict[str, str] = {
177179
"identifier": "must have a SQL name that is a valid SQL identifier",
178180
"sql_keyword": "must have a SQL name that is not a reserved word",
179181
}
180182

183+
def _split_identifier(self, name: str) -> list[str]:
184+
"""
185+
Split a potentially qualified SQL identifier into parts.
186+
187+
Behavior:
188+
- Dots (.) **outside** quotes/backticks separate parts.
189+
- Escaped double quotes "" are allowed inside a quoted name ("...").
190+
- Escaped backticks `` are allowed inside a backtick name (`...`).
191+
- Dots inside quoted/backtick names are literal characters and do not split.
192+
- Returned parts include their surrounding quotes/backticks if present.
193+
(This is intentional, since quoted and unquoted names will be validated differently later.)
194+
- Empty parts may be returned for cases like:
195+
* ".field" → ["", "field"]
196+
* "schema." → ["schema", ""]
197+
* "db..tbl" → ["db", "", "tbl"]
198+
(Validation will decide if empty parts are allowed.)
199+
200+
Notes:
201+
- After closing a quoted/backtick identifier, parsing continues in the same token
202+
until a dot (.) is seen or the string ends. Quotes themselves do not trigger splitting.
203+
- If spaces or other invalid characters appear in a part, the validator will
204+
reject that token later.
205+
206+
Examples:
207+
>>> _split_identifier('schema.table')
208+
['schema', 'table']
209+
210+
>>> _split_identifier('"foo"."bar"')
211+
['"foo"', '"bar"']
212+
213+
>>> _split_identifier('db."table.name"')
214+
['db', '"table.name"']
215+
216+
>>> _split_identifier('`a``b`.`c``d`')
217+
['`a``b`', '`c``d`']
218+
219+
>>> _split_identifier('.field')
220+
['', 'field']
221+
222+
>>> _split_identifier('field.')
223+
['field', '']
224+
"""
225+
226+
class split_states(Enum):
227+
START = auto()
228+
UNQUOTED = auto()
229+
DOUBLE_QUOTE = auto()
230+
BACKTICK = auto()
231+
232+
parts: list[str] = []
233+
start_idx: int = 0
234+
state: split_states = split_states.START
235+
length = len(name)
236+
ii: int = 0
237+
238+
while ii < length:
239+
ch: str = name[ii]
240+
match state:
241+
case split_states.START:
242+
match ch:
243+
case '"':
244+
state = split_states.DOUBLE_QUOTE
245+
ii += 1
246+
case "`":
247+
state = split_states.BACKTICK
248+
ii += 1
249+
case _:
250+
state = split_states.UNQUOTED
251+
case split_states.UNQUOTED:
252+
if ch == ".":
253+
parts.append(name[start_idx:ii])
254+
start_idx = ii + 1
255+
state = split_states.START
256+
ii += 1
257+
case split_states.DOUBLE_QUOTE:
258+
if ch == '"':
259+
if (ii + 1 < length) and (name[ii + 1] == '"'):
260+
ii += 1
261+
else:
262+
state = split_states.UNQUOTED
263+
ii += 1
264+
case split_states.BACKTICK:
265+
if ch == "`":
266+
if (ii + 1 < length) and (name[ii + 1] == "`"):
267+
ii += 1
268+
else:
269+
state = split_states.UNQUOTED
270+
ii += 1
271+
parts.append(name[start_idx:ii])
272+
return parts
273+
181274
def _error_code(self, obj: object) -> str | None:
182275
"""Return an error code if invalid, or None if valid."""
183276
ret_val: str | None = None
184277
# Check that obj is a string
185278
if isinstance(obj, str):
186-
# Check that obj is a valid SQL identifier
187-
if not self.is_valid_sql_identifier(obj):
188-
ret_val = "identifier"
189-
# Check that obj is not a SQL reserved word
190-
elif self._is_sql_keyword(obj):
191-
ret_val = "sql_keyword"
279+
# Check each part of a qualified name: db.schema.table
280+
for part in self._split_identifier(obj):
281+
# Check that obj is a valid SQL identifier
282+
# Empty parts (e.g., leading/trailing dots) are invalid
283+
if not part or not self.is_valid_sql_identifier(part):
284+
ret_val = "identifier"
285+
break
286+
# Check that obj is not a SQL reserved word
287+
if self._is_sql_keyword(part):
288+
ret_val = "sql_keyword"
289+
break
192290
else:
193291
ret_val = "identifier"
194292

@@ -209,7 +307,7 @@ def is_valid_sql_identifier(self, name: str) -> bool:
209307
return False
210308

211309
# Case 1: unquoted
212-
if self._UNQUOTED_SQL_IDENTIFIER.match(name):
310+
if self.UNQUOTED_SQL_IDENTIFIER.match(name):
213311
return True
214312

215313
# Case 2: double quoted

tests/conftest.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def get_test_graph_by_name() -> graph_fetcher:
193193
test_graph_location: dict[str, str] = {
194194
"synthea": "synthea_graph.json",
195195
"world_development_indicators": "world_development_indicators_graph.json",
196+
"keywords": "reserved_words_graph.json",
196197
}
197198

198199
@cache
@@ -616,13 +617,24 @@ def sqlite_custom_datasets_connection() -> DatabaseContext:
616617
"""
617618
Returns the SQLITE database connection with all the custom datasets attached.
618619
"""
619-
commands: list[str] = [
620-
"cd tests/gen_data",
621-
"rm -fv synthea.db",
622-
"rm -fv world_development_indicators.db",
623-
"sqlite3 synthea.db < init_synthea_sqlite.sql",
624-
"sqlite3 world_development_indicators.db < init_world_indicators_sqlite.sql",
620+
gen_data_path: str = "tests/gen_data"
621+
# Dataset tuple format: (schema_name, db_file_name, init_sql_file_name)
622+
SQLite_datasets: list[tuple[str, str, str]] = [
623+
("synthea", "synthea.db", "init_synthea_sqlite.sql"),
624+
("wdi", "world_development_indicators.db", "init_world_indicators_sqlite.sql"),
625+
("keywords", "reserved_words.db", "init_reserved_words_sqlite.sql"),
625626
]
627+
628+
# List of shell commands required to re-create all the db files
629+
commands: list[str] = [f"cd {gen_data_path}"]
630+
# Collect all db_file_names into the rm command
631+
rm_command: str = "rm -fv " + " ".join(
632+
db_file for (_, db_file, _) in SQLite_datasets
633+
)
634+
commands.append(rm_command)
635+
# Add one sqlite3 command per dataset
636+
for _, db_file, init_sql in SQLite_datasets:
637+
commands.append(f"sqlite3 {db_file} < {init_sql}")
626638
# Get the shell commands required to re-create all the db files
627639
shell_cmd: str = "; ".join(commands)
628640

@@ -633,16 +645,11 @@ def sqlite_custom_datasets_connection() -> DatabaseContext:
633645
# Central in-memory connection
634646
connection: sqlite3.Connection = sqlite3.connect(":memory:")
635647

636-
# Dict: schema_name → database file path
637-
dbs: dict[str, str] = {
638-
"synthea": "tests/gen_data/synthea.db",
639-
"wdi": "tests/gen_data/world_development_indicators.db",
640-
}
641-
642-
# Attach them all
643-
for schema, path in dbs.items():
644-
path = os.path.join(base_dir, path)
648+
# Use (schema_name, db_file_name info) on SQLite_datasets to ATTACH DBs
649+
for schema, db_file, _ in SQLite_datasets:
650+
path: str = os.path.join(base_dir, gen_data_path, db_file)
645651
connection.execute(f"ATTACH DATABASE '{path}' AS {schema}")
652+
646653
return DatabaseContext(DatabaseConnection(connection), DatabaseDialect.SQLITE)
647654

648655

0 commit comments

Comments
 (0)