Skip to content

Commit 3c165bb

Browse files
authored
update ruff, cleanup database methods on testclient (#63)
* update ruff, cleanup database_exist method on testclient Changes: - ruff is now version 0.10.0 - ruff versions of precommit and hatch fmt use the same rules - apply ruff rules - cleanup in DatabaseTestClient the database_exist method - the database_exist method is now more lenient against missing permission errors or missing databases for all dbs * remove dependency on sqlalchemy_utils, improve testclient methods * fix typings, readd sqlalchemy_utils as debug dependency * use mysql special user root when possible * add quoting tests * harden against segfaults
1 parent 1fead99 commit 3c165bb

17 files changed

+284
-81
lines changed

.github/workflows/test-suite.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,5 @@ jobs:
9393
run: "hatch run test:check_types"
9494
- name: "Run tests"
9595
if: steps.filters.outputs.src == 'true' || steps.filters.outputs.workflows == 'true' || github.event.schedule != ''
96-
run: env TEST_NO_RISK_SEGFAULTS=true hatch test
96+
# test with xdist as still segfaults occur
97+
run: env TEST_NO_RISK_SEGFAULTS=true hatch test -- -n 1 --dist no

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
- id: end-of-file-fixer
1414
- id: trailing-whitespace
1515
- repo: https://github.com/charliermarsh/ruff-pre-commit
16-
rev: v0.6.4
16+
rev: v0.10.0
1717
hooks:
1818
- id: ruff
1919
args: ["--fix"]

databasez/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from databasez.core import Database, DatabaseURL
22

3-
__version__ = "0.11.1"
3+
__version__ = "0.11.2"
44

55
__all__ = ["Database", "DatabaseURL"]

databasez/core/database.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ async def decr_refcount(self) -> bool:
328328
return False
329329

330330
async def connect_hook(self) -> None:
331-
"""Refcount protected connect hook. Executed begore engine and global connection setup."""
331+
"""Refcount protected connect hook. Executed before engine and global connection setup."""
332332

333333
async def connect(self) -> bool:
334334
"""
@@ -343,9 +343,9 @@ async def connect(self) -> bool:
343343
raise RuntimeError("Subdatabases and polling are disabled")
344344
# copy when not in map
345345
if loop not in self._databases_map:
346-
assert (
347-
self._global_connection is not None
348-
), "global connection should have been set"
346+
assert self._global_connection is not None, (
347+
"global connection should have been set"
348+
)
349349
# correctly initialize force_rollback with parent value
350350
database = self.__class__(
351351
self, force_rollback=bool(self.force_rollback), full_isolation=False
@@ -405,7 +405,9 @@ async def disconnect(
405405
assert not self._databases_map, "sub databases still active, force terminate them"
406406
for sub_database in self._databases_map.values():
407407
await arun_coroutine_threadsafe(
408-
sub_database.disconnect(True), sub_database._loop, self.poll_interval
408+
sub_database.disconnect(True),
409+
sub_database._loop,
410+
self.poll_interval,
409411
)
410412
self._databases_map = {}
411413
assert not self._databases_map, "sub databases still active"
@@ -556,7 +558,9 @@ async def drop_all(self, meta: MetaData, timeout: float | None = None, **kwargs:
556558
@multiloop_protector(False)
557559
def _non_global_connection(
558560
self,
559-
timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout
561+
timeout: (
562+
float | None
563+
) = None, # stub for type checker, multiloop_protector handles timeout
560564
) -> Connection:
561565
if self._connection is None:
562566
_connection = self._connection = Connection(self)
@@ -624,15 +628,17 @@ def get_backends(
624628
module = importlib.import_module(imp_path)
625629
except ImportError as exc:
626630
logging.debug(
627-
f'Import of "{imp_path}" failed. This is not an error.', exc_info=exc
631+
f'Import of "{imp_path}" failed. This is not an error.',
632+
exc_info=exc,
628633
)
629634
if "+" in scheme:
630635
imp_path = f"{overwrite_path}.{scheme.split('+', 1)[0]}"
631636
try:
632637
module = importlib.import_module(imp_path)
633638
except ImportError as exc:
634639
logging.debug(
635-
f'Import of "{imp_path}" failed. This is not an error.', exc_info=exc
640+
f'Import of "{imp_path}" failed. This is not an error.',
641+
exc_info=exc,
636642
)
637643
if module is not None:
638644
break

databasez/dialects/dbapi2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from sqlalchemy.pool import AsyncAdaptedQueuePool
1515
from sqlalchemy.sql import text
1616
from sqlalchemy.util.concurrency import await_only
17-
from sqlalchemy_utils.functions.orm import quote
1817

1918
from databasez.utils import AsyncWrapper
2019

@@ -118,7 +117,8 @@ def has_table(
118117
schema: str | None = None,
119118
**kw: Any,
120119
) -> bool:
121-
stmt = text(f"select * from '{quote(connection, table_name)}' LIMIT 1")
120+
quoted = self.identifier_preparer.quote(table_name)
121+
stmt = text(f"select 1 from '{quoted}' LIMIT 1")
122122
try:
123123
connection.execute(stmt)
124124
return True

databasez/dialects/jdbc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from sqlalchemy.pool import AsyncAdaptedQueuePool
2828
from sqlalchemy.sql import sqltypes, text
2929
from sqlalchemy.util.concurrency import await_only
30-
from sqlalchemy_utils.functions.orm import quote
3130

3231
from databasez.utils import AsyncWrapper
3332

@@ -113,7 +112,8 @@ def has_table(
113112
schema: str | None = None,
114113
**kw: Any,
115114
) -> bool:
116-
stmt = text(f"select * from '{quote(connection, table_name)}' LIMIT 1")
115+
quoted = self.identifier_preparer.quote(table_name)
116+
stmt = text(f"SELECT 1 from '{quoted}' LIMIT 1")
117117
try:
118118
connection.execute(stmt)
119119
return True

databasez/sqlalchemy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ async def start(self, is_root: bool, **extra_options: Any) -> None:
5757
self.old_transaction_level = ""
5858
if extra_options:
5959
await connection.execution_options(**extra_options)
60-
assert (
61-
await connection.get_isolation_level() != "AUTOCOMMIT"
62-
), "transactions doesn't work with AUTOCOMMIT. Please specify another transaction level."
60+
assert await connection.get_isolation_level() != "AUTOCOMMIT", (
61+
"transactions doesn't work with AUTOCOMMIT. Please specify another transaction level."
62+
)
6363
if in_transaction:
6464
self.raw_transaction = await connection.begin_nested()
6565
else:

databasez/testclient.py

Lines changed: 76 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,9 @@
55

66
import sqlalchemy
77
from sqlalchemy.exc import OperationalError, ProgrammingError
8-
from sqlalchemy_utils.functions.database import _sqlite_file_exists
9-
from sqlalchemy_utils.functions.orm import quote
108

119
from databasez import Database, DatabaseURL
12-
from databasez.utils import DATABASEZ_POLL_INTERVAL, ThreadPassingExceptions
13-
14-
15-
async def _get_scalar_result(engine: Any, sql: Any) -> Any:
16-
try:
17-
async with engine.connect() as conn:
18-
return await conn.scalar(sql)
19-
except Exception:
20-
return False
10+
from databasez.utils import DATABASEZ_POLL_INTERVAL, ThreadPassingExceptions, get_quoter
2111

2212

2313
class DatabaseTestClient(Database):
@@ -146,41 +136,60 @@ async def database_exists(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) -
146136
database = url.database
147137
dialect_name = url.sqla_url.get_dialect(True).name
148138
if dialect_name == "postgresql":
149-
text = f"SELECT 1 FROM pg_database WHERE datname='{database}'"
150-
for db in (database, "postgres", "template1", "template0", None):
139+
text = "SELECT 1 FROM pg_database WHERE datname=:database"
140+
for db in (database, None, "postgres", "template1", "template0"):
151141
url = url.replace(database=db)
152-
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
153-
try:
154-
return bool(
155-
await _get_scalar_result(db_client.engine, sqlalchemy.text(text))
156-
)
157-
except (ProgrammingError, OperationalError):
158-
pass
142+
try:
143+
async with Database(
144+
url, full_isolation=False, force_rollback=False
145+
) as db_client:
146+
if await db_client.fetch_val(
147+
# if we can connect to the db, it exists
148+
"SELECT 1"
149+
if db == database
150+
else sqlalchemy.text(text).bindparams(database=database)
151+
):
152+
return True
153+
except Exception:
154+
pass
159155
return False
160156

161157
elif dialect_name == "mysql":
162-
url = url.replace(database=None)
163-
text = (
164-
"SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
165-
f"WHERE SCHEMA_NAME = '{database}'"
166-
)
167-
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
168-
return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text)))
158+
for db in (database, None, "root"):
159+
url = url.replace(database=db)
160+
try:
161+
async with Database(
162+
url, full_isolation=False, force_rollback=False
163+
) as db_client:
164+
if await db_client.fetch_val(
165+
(
166+
# if we can connect to the db, it exists
167+
"SELECT 1"
168+
if db == database
169+
else sqlalchemy.text(
170+
"SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = :database"
171+
).bindparams(database=database)
172+
),
173+
):
174+
return True
175+
except Exception:
176+
pass
177+
return False
169178

170179
elif dialect_name == "sqlite":
171180
if database:
172-
return database == ":memory:" or _sqlite_file_exists(database)
181+
return database == ":memory:" or os.path.exists(database)
173182
else:
174183
# The default SQLAlchemy database is in memory, and :memory: is
175184
# not required, thus we should support that use case.
176185
return True
177186
else:
178-
text = "SELECT 1"
179-
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
180-
try:
181-
return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text)))
182-
except (ProgrammingError, OperationalError):
183-
return False
187+
try:
188+
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
189+
await db_client.fetch_val("SELECT 1")
190+
return True
191+
except Exception:
192+
return False
184193

185194
@classmethod
186195
async def create_database(
@@ -191,8 +200,9 @@ async def create_database(
191200
) -> None:
192201
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
193202
database = url.database
194-
dialect_name = url.sqla_url.get_dialect(True).name
195-
dialect_driver = url.sqla_url.get_dialect(True).driver
203+
dialect = url.sqla_url.get_dialect(True)
204+
dialect_name = dialect.name
205+
dialect_driver = dialect.driver
196206

197207
# we don't want to connect to a not existing db
198208
if dialect_name == "postgresql":
@@ -209,7 +219,10 @@ async def create_database(
209219
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
210220
):
211221
db_client = Database(
212-
url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
222+
url,
223+
isolation_level="AUTOCOMMIT",
224+
force_rollback=False,
225+
full_isolation=False,
213226
)
214227
else:
215228
db_client = Database(url, force_rollback=False, full_isolation=False)
@@ -218,29 +231,34 @@ async def create_database(
218231
if not template:
219232
template = "template1"
220233

221-
async with db_client.engine.begin() as conn: # type: ignore
234+
async with db_client.connection() as conn:
235+
quote = get_quoter(conn.async_connection)
222236
text = (
223-
f"CREATE DATABASE {quote(conn, database)} ENCODING "
224-
f"'{encoding}' TEMPLATE {quote(conn, template)}"
237+
f"CREATE DATABASE {quote(database)} ENCODING "
238+
f"'{encoding}' TEMPLATE {quote(template)}"
225239
)
226240
await conn.execute(sqlalchemy.text(text))
227241

228242
elif dialect_name == "mysql":
229-
async with db_client.engine.begin() as conn: # type: ignore
230-
text = f"CREATE DATABASE {quote(conn, database)} CHARACTER SET = '{encoding}'"
243+
async with db_client.connection() as conn:
244+
quote = get_quoter(conn.async_connection)
245+
text = f"CREATE DATABASE {quote(database)} CHARACTER SET = '{quote(encoding)}'"
231246
await conn.execute(sqlalchemy.text(text))
232247

233248
elif dialect_name == "sqlite" and database != ":memory:":
234249
if database:
235250
# create a sqlite file
236-
async with db_client.engine.begin() as conn: # type: ignore
237-
await conn.execute(sqlalchemy.text("CREATE TABLE DB(id int)"))
238-
await conn.execute(sqlalchemy.text("DROP TABLE DB"))
251+
async with (
252+
db_client.connection() as conn,
253+
conn.transaction(force_rollback=False),
254+
):
255+
await conn.execute("CREATE TABLE _dropme_DB(id int)")
256+
await conn.execute("DROP TABLE _dropme_DB")
239257

240258
else:
241-
async with db_client.engine.begin() as conn: # type: ignore
242-
text = f"CREATE DATABASE {quote(conn, database)}"
243-
await conn.execute(sqlalchemy.text(text))
259+
async with db_client.connection() as conn:
260+
quote = get_quoter(conn.async_connection)
261+
await conn.execute(sqlalchemy.text(f"CREATE DATABASE {quote(database)}"))
244262

245263
@classmethod
246264
async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) -> None:
@@ -264,7 +282,10 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
264282
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
265283
):
266284
db_client = Database(
267-
url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
285+
url,
286+
isolation_level="AUTOCOMMIT",
287+
force_rollback=False,
288+
full_isolation=False,
268289
)
269290
else:
270291
db_client = Database(url, force_rollback=False, full_isolation=False)
@@ -274,6 +295,7 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
274295
os.remove(database)
275296
elif dialect_name.startswith("postgres"):
276297
async with db_client.connection() as conn:
298+
quote = get_quoter(conn.async_connection)
277299
# Disconnect all users from the database we are dropping.
278300
server_version_raw = (
279301
await conn.fetch_val(
@@ -282,7 +304,7 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
282304
).split(" ")[0]
283305
version = tuple(map(int, server_version_raw.split(".")))
284306
pid_column = "pid" if (version >= (9, 2)) else "procpid"
285-
quoted_db = quote(conn.async_connection, database)
307+
quoted_db = quote(database)
286308
text = f"""
287309
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
288310
FROM pg_stat_activity
@@ -297,8 +319,10 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
297319
await conn.execute(text)
298320
else:
299321
async with db_client.connection() as conn:
300-
text = f"DROP DATABASE {quote(conn.async_connection, database)}"
301-
await conn.execute(sqlalchemy.text(text))
322+
quote = get_quoter(conn.async_connection)
323+
text = f"DROP DATABASE {quote(database)}"
324+
with contextlib.suppress(ProgrammingError):
325+
await conn.execute(sqlalchemy.text(text))
302326

303327
def drop_db_protected(self) -> None:
304328
thread = ThreadPassingExceptions(

databasez/utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from threading import Thread
99
from typing import Any, TypeVar, cast
1010

11+
from sqlalchemy.engine import Connection as SQLAConnection
12+
from sqlalchemy.engine import Dialect
13+
1114
DATABASEZ_RESULT_TIMEOUT: float | None = None
1215
# Poll with 0.1ms, this way CPU isn't at 100%
1316
DATABASEZ_POLL_INTERVAL: float = 0.0001
@@ -170,9 +173,11 @@ def run(self) -> None:
170173
self._exc_raised = exc
171174

172175
def join(self, timeout: float | int | None = None) -> None:
173-
super().join(timeout=timeout)
174-
if self._exc_raised:
175-
raise self._exc_raised
176+
try:
177+
super().join(timeout=timeout)
178+
finally:
179+
if self._exc_raised is not None:
180+
raise self._exc_raised
176181

177182

178183
MultiloopProtectorCallable = TypeVar("MultiloopProtectorCallable", bound=Callable)
@@ -232,3 +237,19 @@ def wrapper(
232237
return cast(MultiloopProtectorCallable, wrapper)
233238

234239
return _decorator
240+
241+
242+
def get_dialect(async_conn_or_dialect: SQLAConnection | Dialect, /) -> Dialect:
243+
if isinstance(async_conn_or_dialect, Dialect):
244+
return async_conn_or_dialect
245+
if hasattr(async_conn_or_dialect, "bind"):
246+
async_conn_or_dialect = async_conn_or_dialect.bind
247+
return cast("SQLAConnection", async_conn_or_dialect).dialect
248+
249+
250+
def get_quoter(async_conn_or_dialect: SQLAConnection | Dialect, /) -> Callable[[str], str]:
251+
# needs underlying async connection as object or dialect
252+
dialect = get_dialect(async_conn_or_dialect)
253+
if hasattr(dialect, "identifier_preparer"):
254+
return dialect.identifier_preparer.quote
255+
return dialect.preparer(dialect).quote

0 commit comments

Comments
 (0)