Skip to content

Commit 2ca97c5

Browse files
authored
feat: implement sqlglot query parsing (#23)
Implements some basic query parsing with `sqlglot`
1 parent e3701ce commit 2ca97c5

File tree

18 files changed

+1781
-707
lines changed

18 files changed

+1781
-707
lines changed

sqlspec/adapters/adbc/driver.py

Lines changed: 155 additions & 120 deletions
Large diffs are not rendered by default.

sqlspec/adapters/aiosqlite/driver.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class AiosqliteDriver(AsyncDriverAdapterProtocol["Connection"]):
1717
"""SQLite Async Driver Adapter."""
1818

1919
connection: "Connection"
20+
dialect: str = "sqlite"
2021

2122
def __init__(self, connection: "Connection") -> None:
2223
self.connection = connection
@@ -33,50 +34,23 @@ async def _with_cursor(self, connection: "Connection") -> "AsyncGenerator[Cursor
3334
finally:
3435
await cursor.close()
3536

36-
def _process_sql_params(
37-
self, sql: str, parameters: "Optional[StatementParameterType]" = None
38-
) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
39-
"""Process SQL query and parameters for DB-API execution.
40-
41-
Converts named parameters (:name) to positional parameters (?) for SQLite.
42-
43-
Args:
44-
sql: The SQL query string.
45-
parameters: The parameters for the query (dict, tuple, list, or None).
46-
47-
Returns:
48-
A tuple containing the processed SQL string and the processed parameters.
49-
"""
50-
if not isinstance(parameters, dict) or not parameters:
51-
# If parameters are not a dict, or empty dict, assume positional/no params
52-
# Let the underlying driver handle tuples/lists directly
53-
return sql, parameters
54-
55-
# Convert named parameters to positional parameters
56-
processed_sql = sql
57-
processed_params: list[Any] = []
58-
for key, value in parameters.items():
59-
# Replace :key with ? in the SQL
60-
processed_sql = processed_sql.replace(f":{key}", "?")
61-
processed_params.append(value)
62-
63-
return processed_sql, tuple(processed_params)
64-
6537
async def select(
6638
self,
6739
sql: str,
6840
parameters: Optional["StatementParameterType"] = None,
6941
/,
42+
*,
7043
connection: Optional["Connection"] = None,
7144
schema_type: "Optional[type[ModelDTOT]]" = None,
45+
**kwargs: Any,
7246
) -> "list[Union[ModelDTOT, dict[str, Any]]]":
7347
"""Fetch data from the database.
7448
7549
Returns:
7650
List of row data as either model instances or dictionaries.
7751
"""
7852
connection = self._connection(connection)
79-
sql, parameters = self._process_sql_params(sql, parameters)
53+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
8054
async with self._with_cursor(connection) as cursor:
8155
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
8256
results = await cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
@@ -92,16 +66,18 @@ async def select_one(
9266
sql: str,
9367
parameters: Optional["StatementParameterType"] = None,
9468
/,
69+
*,
9570
connection: Optional["Connection"] = None,
9671
schema_type: "Optional[type[ModelDTOT]]" = None,
72+
**kwargs: Any,
9773
) -> "Union[ModelDTOT, dict[str, Any]]":
9874
"""Fetch one row from the database.
9975
10076
Returns:
10177
The first row of the query results.
10278
"""
10379
connection = self._connection(connection)
104-
sql, parameters = self._process_sql_params(sql, parameters)
80+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
10581
async with self._with_cursor(connection) as cursor:
10682
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
10783
result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
@@ -116,16 +92,18 @@ async def select_one_or_none(
11692
sql: str,
11793
parameters: Optional["StatementParameterType"] = None,
11894
/,
95+
*,
11996
connection: Optional["Connection"] = None,
12097
schema_type: "Optional[type[ModelDTOT]]" = None,
98+
**kwargs: Any,
12199
) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
122100
"""Fetch one row from the database.
123101
124102
Returns:
125103
The first row of the query results.
126104
"""
127105
connection = self._connection(connection)
128-
sql, parameters = self._process_sql_params(sql, parameters)
106+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
129107
async with self._with_cursor(connection) as cursor:
130108
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
131109
result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
@@ -141,16 +119,18 @@ async def select_value(
141119
sql: str,
142120
parameters: "Optional[StatementParameterType]" = None,
143121
/,
122+
*,
144123
connection: "Optional[Connection]" = None,
145124
schema_type: "Optional[type[T]]" = None,
125+
**kwargs: Any,
146126
) -> "Union[T, Any]":
147127
"""Fetch a single value from the database.
148128
149129
Returns:
150130
The first value from the first row of results, or None if no results.
151131
"""
152132
connection = self._connection(connection)
153-
sql, parameters = self._process_sql_params(sql, parameters)
133+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
154134
async with self._with_cursor(connection) as cursor:
155135
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
156136
result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType]
@@ -164,17 +144,18 @@ async def select_value_or_none(
164144
sql: str,
165145
parameters: "Optional[StatementParameterType]" = None,
166146
/,
147+
*,
167148
connection: "Optional[Connection]" = None,
168149
schema_type: "Optional[type[T]]" = None,
150+
**kwargs: Any,
169151
) -> "Optional[Union[T, Any]]":
170152
"""Fetch a single value from the database.
171153
172154
Returns:
173155
The first value from the first row of results, or None if no results.
174156
"""
175157
connection = self._connection(connection)
176-
sql, parameters = self._process_sql_params(sql, parameters)
177-
158+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
178159
async with self._with_cursor(connection) as cursor:
179160
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
180161
result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType]
@@ -189,15 +170,17 @@ async def insert_update_delete(
189170
sql: str,
190171
parameters: Optional["StatementParameterType"] = None,
191172
/,
173+
*,
192174
connection: Optional["Connection"] = None,
175+
**kwargs: Any,
193176
) -> int:
194177
"""Insert, update, or delete data from the database.
195178
196179
Returns:
197180
Row count affected by the operation.
198181
"""
199182
connection = self._connection(connection)
200-
sql, parameters = self._process_sql_params(sql, parameters)
183+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
201184

202185
async with self._with_cursor(connection) as cursor:
203186
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
@@ -208,16 +191,18 @@ async def insert_update_delete_returning(
208191
sql: str,
209192
parameters: Optional["StatementParameterType"] = None,
210193
/,
194+
*,
211195
connection: Optional["Connection"] = None,
212196
schema_type: "Optional[type[ModelDTOT]]" = None,
197+
**kwargs: Any,
213198
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
214199
"""Insert, update, or delete data from the database and return result.
215200
216201
Returns:
217202
The first row of results.
218203
"""
219204
connection = self._connection(connection)
220-
sql, parameters = self._process_sql_params(sql, parameters)
205+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
221206

222207
async with self._with_cursor(connection) as cursor:
223208
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
@@ -234,15 +219,17 @@ async def execute_script(
234219
sql: str,
235220
parameters: Optional["StatementParameterType"] = None,
236221
/,
222+
*,
237223
connection: Optional["Connection"] = None,
224+
**kwargs: Any,
238225
) -> str:
239226
"""Execute a script.
240227
241228
Returns:
242229
Status message for the operation.
243230
"""
244231
connection = self._connection(connection)
245-
sql, parameters = self._process_sql_params(sql, parameters)
232+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
246233

247234
async with self._with_cursor(connection) as cursor:
248235
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
@@ -253,16 +240,18 @@ async def execute_script_returning(
253240
sql: str,
254241
parameters: Optional["StatementParameterType"] = None,
255242
/,
243+
*,
256244
connection: Optional["Connection"] = None,
257245
schema_type: "Optional[type[ModelDTOT]]" = None,
246+
**kwargs: Any,
258247
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
259248
"""Execute a script and return result.
260249
261250
Returns:
262251
The first row of results.
263252
"""
264253
connection = self._connection(connection)
265-
sql, parameters = self._process_sql_params(sql, parameters)
254+
sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
266255

267256
async with self._with_cursor(connection) as cursor:
268257
await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]

0 commit comments

Comments
 (0)