55
66import sqlalchemy
77from sqlalchemy .exc import OperationalError , ProgrammingError
8- from sqlalchemy_utils .functions .database import _sqlite_file_exists
9- from sqlalchemy_utils .functions .orm import quote
108
119from databasez import Database , DatabaseURL
12- from databasez .utils import DATABASEZ_POLL_INTERVAL , ThreadPassingExceptions
13-
14-
15- async def _get_scalar_result (engine : Any , sql : Any ) -> Any :
16- try :
17- async with engine .connect () as conn :
18- return await conn .scalar (sql )
19- except Exception :
20- return False
10+ from databasez .utils import DATABASEZ_POLL_INTERVAL , ThreadPassingExceptions , get_quoter
2111
2212
2313class DatabaseTestClient (Database ):
@@ -146,41 +136,60 @@ async def database_exists(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) -
146136 database = url .database
147137 dialect_name = url .sqla_url .get_dialect (True ).name
148138 if dialect_name == "postgresql" :
149- text = f "SELECT 1 FROM pg_database WHERE datname=' { database } ' "
150- for db in (database , "postgres" , "template1" , "template0" , None ):
139+ text = "SELECT 1 FROM pg_database WHERE datname=: database"
140+ for db in (database , None , "postgres" , "template1" , "template0" ):
151141 url = url .replace (database = db )
152- async with Database (url , full_isolation = False , force_rollback = False ) as db_client :
153- try :
154- return bool (
155- await _get_scalar_result (db_client .engine , sqlalchemy .text (text ))
156- )
157- except (ProgrammingError , OperationalError ):
158- pass
142+ try :
143+ async with Database (
144+ url , full_isolation = False , force_rollback = False
145+ ) as db_client :
146+ if await db_client .fetch_val (
147+ # if we can connect to the db, it exists
148+ "SELECT 1"
149+ if db == database
150+ else sqlalchemy .text (text ).bindparams (database = database )
151+ ):
152+ return True
153+ except Exception :
154+ pass
159155 return False
160156
161157 elif dialect_name == "mysql" :
162- url = url .replace (database = None )
163- text = (
164- "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
165- f"WHERE SCHEMA_NAME = '{ database } '"
166- )
167- async with Database (url , full_isolation = False , force_rollback = False ) as db_client :
168- return bool (await _get_scalar_result (db_client .engine , sqlalchemy .text (text )))
158+ for db in (database , None , "root" ):
159+ url = url .replace (database = db )
160+ try :
161+ async with Database (
162+ url , full_isolation = False , force_rollback = False
163+ ) as db_client :
164+ if await db_client .fetch_val (
165+ (
166+ # if we can connect to the db, it exists
167+ "SELECT 1"
168+ if db == database
169+ else sqlalchemy .text (
170+ "SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = :database"
171+ ).bindparams (database = database )
172+ ),
173+ ):
174+ return True
175+ except Exception :
176+ pass
177+ return False
169178
170179 elif dialect_name == "sqlite" :
171180 if database :
172- return database == ":memory:" or _sqlite_file_exists (database )
181+ return database == ":memory:" or os . path . exists (database )
173182 else :
174183 # The default SQLAlchemy database is in memory, and :memory: is
175184 # not required, thus we should support that use case.
176185 return True
177186 else :
178- text = "SELECT 1"
179- async with Database (url , full_isolation = False , force_rollback = False ) as db_client :
180- try :
181- return bool ( await _get_scalar_result ( db_client . engine , sqlalchemy . text ( text )))
182- except ( ProgrammingError , OperationalError ) :
183- return False
187+ try :
188+ async with Database (url , full_isolation = False , force_rollback = False ) as db_client :
189+ await db_client . fetch_val ( "SELECT 1" )
190+ return True
191+ except Exception :
192+ return False
184193
185194 @classmethod
186195 async def create_database (
@@ -191,8 +200,9 @@ async def create_database(
191200 ) -> None :
192201 url = url if isinstance (url , DatabaseURL ) else DatabaseURL (url )
193202 database = url .database
194- dialect_name = url .sqla_url .get_dialect (True ).name
195- dialect_driver = url .sqla_url .get_dialect (True ).driver
203+ dialect = url .sqla_url .get_dialect (True )
204+ dialect_name = dialect .name
205+ dialect_driver = dialect .driver
196206
197207 # we don't want to connect to a not existing db
198208 if dialect_name == "postgresql" :
@@ -209,7 +219,10 @@ async def create_database(
209219 and dialect_driver in {"asyncpg" , "pg8000" , "psycopg" , "psycopg2" , "psycopg2cffi" }
210220 ):
211221 db_client = Database (
212- url , isolation_level = "AUTOCOMMIT" , force_rollback = False , full_isolation = False
222+ url ,
223+ isolation_level = "AUTOCOMMIT" ,
224+ force_rollback = False ,
225+ full_isolation = False ,
213226 )
214227 else :
215228 db_client = Database (url , force_rollback = False , full_isolation = False )
@@ -218,29 +231,34 @@ async def create_database(
218231 if not template :
219232 template = "template1"
220233
221- async with db_client .engine .begin () as conn : # type: ignore
234+ async with db_client .connection () as conn :
235+ quote = get_quoter (conn .async_connection )
222236 text = (
223- f"CREATE DATABASE { quote (conn , database )} ENCODING "
224- f"'{ encoding } ' TEMPLATE { quote (conn , template )} "
237+ f"CREATE DATABASE { quote (database )} ENCODING "
238+ f"'{ encoding } ' TEMPLATE { quote (template )} "
225239 )
226240 await conn .execute (sqlalchemy .text (text ))
227241
228242 elif dialect_name == "mysql" :
229- async with db_client .engine .begin () as conn : # type: ignore
230- text = f"CREATE DATABASE { quote (conn , database )} CHARACTER SET = '{ encoding } '"
243+ async with db_client .connection () as conn :
244+ quote = get_quoter (conn .async_connection )
245+ text = f"CREATE DATABASE { quote (database )} CHARACTER SET = '{ quote (encoding )} '"
231246 await conn .execute (sqlalchemy .text (text ))
232247
233248 elif dialect_name == "sqlite" and database != ":memory:" :
234249 if database :
235250 # create a sqlite file
236- async with db_client .engine .begin () as conn : # type: ignore
237- await conn .execute (sqlalchemy .text ("CREATE TABLE DB(id int)" ))
238- await conn .execute (sqlalchemy .text ("DROP TABLE DB" ))
251+ async with (
252+ db_client .connection () as conn ,
253+ conn .transaction (force_rollback = False ),
254+ ):
255+ await conn .execute ("CREATE TABLE _dropme_DB(id int)" )
256+ await conn .execute ("DROP TABLE _dropme_DB" )
239257
240258 else :
241- async with db_client .engine . begin () as conn : # type: ignore
242- text = f"CREATE DATABASE { quote (conn , database ) } "
243- await conn .execute (sqlalchemy .text (text ))
259+ async with db_client .connection () as conn :
260+ quote = get_quoter (conn . async_connection )
261+ await conn .execute (sqlalchemy .text (f"CREATE DATABASE { quote ( database ) } " ))
244262
245263 @classmethod
246264 async def drop_database (cls , url : Union [str , "sqlalchemy.URL" , DatabaseURL ]) -> None :
@@ -264,7 +282,10 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
264282 and dialect_driver in {"asyncpg" , "pg8000" , "psycopg" , "psycopg2" , "psycopg2cffi" }
265283 ):
266284 db_client = Database (
267- url , isolation_level = "AUTOCOMMIT" , force_rollback = False , full_isolation = False
285+ url ,
286+ isolation_level = "AUTOCOMMIT" ,
287+ force_rollback = False ,
288+ full_isolation = False ,
268289 )
269290 else :
270291 db_client = Database (url , force_rollback = False , full_isolation = False )
@@ -274,6 +295,7 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
274295 os .remove (database )
275296 elif dialect_name .startswith ("postgres" ):
276297 async with db_client .connection () as conn :
298+ quote = get_quoter (conn .async_connection )
277299 # Disconnect all users from the database we are dropping.
278300 server_version_raw = (
279301 await conn .fetch_val (
@@ -282,7 +304,7 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
282304 ).split (" " )[0 ]
283305 version = tuple (map (int , server_version_raw .split ("." )))
284306 pid_column = "pid" if (version >= (9 , 2 )) else "procpid"
285- quoted_db = quote (conn . async_connection , database )
307+ quoted_db = quote (database )
286308 text = f"""
287309 SELECT pg_terminate_backend(pg_stat_activity.{ pid_column } )
288310 FROM pg_stat_activity
@@ -297,8 +319,10 @@ async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) ->
297319 await conn .execute (text )
298320 else :
299321 async with db_client .connection () as conn :
300- text = f"DROP DATABASE { quote (conn .async_connection , database )} "
301- await conn .execute (sqlalchemy .text (text ))
322+ quote = get_quoter (conn .async_connection )
323+ text = f"DROP DATABASE { quote (database )} "
324+ with contextlib .suppress (ProgrammingError ):
325+ await conn .execute (sqlalchemy .text (text ))
302326
303327 def drop_db_protected (self ) -> None :
304328 thread = ThreadPassingExceptions (
0 commit comments