Skip to content

Commit 9b070b1

Browse files
committed
optimize db dropping
1 parent 3c165bb commit 9b070b1

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

databasez/testclient.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,9 @@ async def create_database(
261261
await conn.execute(sqlalchemy.text(f"CREATE DATABASE {quote(database)}"))
262262

263263
@classmethod
264-
async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) -> None:
264+
async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL], *, use_if_exists: bool = True) -> None:
265265
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
266+
exists_text = "IF EXISTS " if use_if_exists else ""
266267
database = url.database
267268
dialect = url.sqla_url.get_dialect(True)
268269
dialect_name = dialect.name
@@ -277,6 +278,12 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
277278
elif dialect_name != "sqlite":
278279
url = url.replace(database=None)
279280

281+
if dialect_name == "sqlite" :
282+
if database and database != ":memory:":
283+
with contextlib.suppress(FileNotFoundError):
284+
os.remove(database)
285+
return
286+
280287
if (dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}) or (
281288
dialect_name == "postgresql"
282289
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
@@ -290,10 +297,7 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
290297
else:
291298
db_client = Database(url, force_rollback=False, full_isolation=False)
292299
async with db_client:
293-
if dialect_name == "sqlite" and database and database != ":memory:":
294-
with contextlib.suppress(FileNotFoundError):
295-
os.remove(database)
296-
elif dialect_name.startswith("postgres"):
300+
if dialect_name.startswith("postgres"):
297301
async with db_client.connection() as conn:
298302
quote = get_quoter(conn.async_connection)
299303
# Disconnect all users from the database we are dropping.
@@ -314,15 +318,16 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
314318
await conn.execute(text)
315319

316320
# Drop the database.
317-
text = f"DROP DATABASE {quoted_db}"
321+
text = f"DROP DATABASE {exists_text}{quoted_db}"
318322
with contextlib.suppress(ProgrammingError):
319323
await conn.execute(text)
320324
else:
321325
async with db_client.connection() as conn:
322326
quote = get_quoter(conn.async_connection)
323-
text = f"DROP DATABASE {quote(database)}"
324327
with contextlib.suppress(ProgrammingError):
325-
await conn.execute(sqlalchemy.text(text))
328+
text = f"DROP DATABASE {exists_text}{quote(database)}"
329+
await conn.execute(text)
330+
326331

327332
def drop_db_protected(self) -> None:
328333
thread = ThreadPassingExceptions(

0 commit comments

Comments
 (0)