Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 89 additions & 85 deletions sqlalchemy_utils/functions/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import os
from collections.abc import Mapping, Sequence
from contextlib import contextmanager
from copy import copy

import sqlalchemy as sa
Expand Down Expand Up @@ -442,6 +443,36 @@ def _sqlite_file_exists(database):
return header[:16] == b'SQLite format 3\x00'


@contextmanager
def _create_engine(url, use_primary_db=False, **engine_kwargs):
"""A context manager that provies a SQLAlchemy engine.

:param url: A SQLAlchemy engine URL.
:param use_primary_db: If True, connects to the primary database of the
given database server. This is necessary for operations such as creating
or dropping databases. If False, connects to the database specified in the url.
:param engine_kwargs: Additional keyword arguments passed to sqlalchemy's create_engine function.
"""
url = make_url(url)
dialect_name = url.get_dialect().name

if use_primary_db:
if dialect_name == 'postgresql':
url = _set_url_database(url, database='postgres')
elif dialect_name == 'mssql':
url = _set_url_database(url, database='master')
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database='defaultdb')
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)

engine = sa.create_engine(url, **engine_kwargs)
try:
yield engine
finally:
engine.dispose()


def database_exists(url):
"""Check if a database exists.

Expand All @@ -466,55 +497,48 @@ def database_exists(url):
url = make_url(url)
database = url.database
dialect_name = url.get_dialect().name
engine = None
try:
if dialect_name == 'postgresql':
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, 'postgres', 'template1', 'template0', None):
url = _set_url_database(url, database=db)
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
if dialect_name == 'postgresql':
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, 'postgres', 'template1', 'template0', None):
url = _set_url_database(url, database=db)
with _create_engine(url) as engine:
try:
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
pass
return False
return False

elif dialect_name == 'mysql':
url = _set_url_database(url, database=None)
engine = sa.create_engine(url)
elif dialect_name == 'mysql':
url = _set_url_database(url, database=None)
with _create_engine(url) as engine:
text = (
'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA '
"WHERE SCHEMA_NAME = '%s'" % database
)
return bool(_get_scalar_result(engine, sa.text(text)))

elif dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
engine = sa.create_engine(url)
if database:
return database == ':memory:' or _sqlite_file_exists(database)
else:
# The default SQLAlchemy database is in memory, and :memory: is
# not required, thus we should support that use case.
return True
elif dialect_name == 'mssql':
elif dialect_name == 'sqlite':
if database:
return database == ':memory:' or _sqlite_file_exists(database)
else:
# The default SQLAlchemy database is in memory, and :memory: is
# not required, thus we should support that use case.
return True
elif dialect_name == 'mssql':
url = _set_url_database(url, database='master')
with _create_engine(url, isolation_level='AUTOCOMMIT') as engine:
text = "SELECT 1 FROM sys.databases WHERE name = '%s'" % database
url = _set_url_database(url, database='master')
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
try:
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
return False
else:
else:
with _create_engine(url) as engine:
text = 'SELECT 1'
try:
engine = sa.create_engine(url)
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
return False
finally:
if engine:
engine.dispose()


def create_database(url, encoding='utf8', template=None):
Expand Down Expand Up @@ -544,53 +568,43 @@ def create_database(url, encoding='utf8', template=None):
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

if dialect_name == 'postgresql':
url = _set_url_database(url, database='postgres')
elif dialect_name == 'mssql':
url = _set_url_database(url, database='master')
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database='defaultdb')
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)

if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) or (
dialect_name == 'postgresql'
and dialect_driver
in {'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}
):
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
engine_kwargs = {'isolation_level': 'AUTOCOMMIT'}
else:
engine = sa.create_engine(url)
engine_kwargs = {}

if dialect_name == 'postgresql':
if not template:
template = 'template1'

with engine.begin() as conn:
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database), encoding, quote(conn, template)
)
conn.execute(sa.text(text))
with _create_engine(url, use_primary_db=True, **engine_kwargs) as engine:
if dialect_name == 'postgresql':
if not template:
template = 'template1'

elif dialect_name == 'mysql':
with engine.begin() as conn:
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database), encoding
)
conn.execute(sa.text(text))
with engine.begin() as conn:
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database), encoding, quote(conn, template)
)
conn.execute(sa.text(text))

elif dialect_name == 'sqlite' and database != ':memory:':
if database:
elif dialect_name == 'mysql':
with engine.begin() as conn:
conn.execute(sa.text('CREATE TABLE DB(id int)'))
conn.execute(sa.text('DROP TABLE DB'))
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database), encoding
)
conn.execute(sa.text(text))

else:
with engine.begin() as conn:
text = f'CREATE DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))
elif dialect_name == 'sqlite' and database != ':memory:':
if database:
with engine.begin() as conn:
conn.execute(sa.text('CREATE TABLE DB(id int)'))
conn.execute(sa.text('DROP TABLE DB'))

engine.dispose()
else:
with engine.begin() as conn:
text = f'CREATE DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))


def drop_database(url):
Expand All @@ -613,30 +627,20 @@ def drop_database(url):
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

if dialect_name == 'postgresql':
url = _set_url_database(url, database='postgres')
elif dialect_name == 'mssql':
url = _set_url_database(url, database='master')
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database='defaultdb')
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)

if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) or (
dialect_name == 'postgresql'
and dialect_driver
in {'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}
):
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
else:
engine = sa.create_engine(url)

if dialect_name == 'sqlite' and database != ':memory:':
if database:
os.remove(database)
else:
with engine.begin() as conn:
text = f'DROP DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))
if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) or (
dialect_name == 'postgresql'
and dialect_driver
in {'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}
):
engine_kwargs = {'isolation_level': 'AUTOCOMMIT'}
else:
engine_kwargs = {}

engine.dispose()
with _create_engine(url, use_primary_db=True, **engine_kwargs) as engine:
with engine.begin() as conn:
text = f'DROP DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))
21 changes: 21 additions & 0 deletions tests/functions/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from sqlalchemy_utils import create_database, database_exists, drop_database
from sqlalchemy_utils.compat import get_sqlalchemy_version
from sqlalchemy_utils.functions.database import _create_engine

pymysql = None
try:
Expand Down Expand Up @@ -163,3 +164,23 @@ class TestDatabaseMssql(DatabaseTest):
def db_name(self):
pytest.importorskip('pyodbc')
return 'db_test_sqlalchemy_util'


def test_create_engine(sqlite_memory_dsn):
"""Test that engine creation context manager creates an engine and disposes of it"""
with _create_engine(sqlite_memory_dsn) as engine:
pool = engine.pool
with engine.connect() as conn:
assert conn.execute(sa.text('SELECT 1')).scalar() == 1

assert engine.pool is not pool, "Engine was not disposed because pool is the same"


def test_create_engine_always_disposes(sqlite_memory_dsn):
"""Test that engine creation context manager still disposes of an engine when an exception is raised."""
with pytest.raises(RuntimeError, match='it failed'):
with _create_engine(sqlite_memory_dsn) as engine:
pool = engine.pool
raise RuntimeError('it failed')

assert engine.pool is not pool, "Engine was not disposed because pool is the same"