Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit bf81282

Browse files
Merge pull request #38 from encode/ensure-global-rollback-initialized-on-init
Ensure global connections are created on __init__.
2 parents b409ea7 + d86b2c3 commit bf81282

File tree

5 files changed

+23
-19
lines changed

5 files changed

+23
-19
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
*.pyc
2+
test.db
23
.coverage
34
.pytest_cache/
45
.mypy_cache/

databases/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from databases.core import Database, DatabaseURL
22

33

4-
__version__ = "0.1.3"
4+
__version__ = "0.1.4"
55
__all__ = ["Database", "DatabaseURL"]

databases/backends/mysql.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ async def disconnect(self) -> None:
4040
self._pool = None
4141

4242
def connection(self) -> "MySQLConnection":
43-
assert self._pool is not None, "DatabaseBackend is not running"
44-
return MySQLConnection(self._pool, self._dialect)
43+
return MySQLConnection(self, self._dialect)
4544

4645

4746
class CompilationContext:
@@ -50,18 +49,20 @@ def __init__(self, context: ExecutionContext):
5049

5150

5251
class MySQLConnection(ConnectionBackend):
53-
def __init__(self, pool: aiomysql.pool.Pool, dialect: Dialect):
54-
self._pool = pool
52+
def __init__(self, database: MySQLBackend, dialect: Dialect):
53+
self._database = database
5554
self._dialect = dialect
5655
self._connection = None # type: typing.Optional[aiomysql.Connection]
5756

5857
async def acquire(self) -> None:
5958
assert self._connection is None, "Connection is already acquired"
60-
self._connection = await self._pool.acquire()
59+
assert self._database._pool is not None, "DatabaseBackend is not running"
60+
self._connection = await self._database._pool.acquire()
6161

6262
async def release(self) -> None:
6363
assert self._connection is not None, "Connection is not acquired"
64-
await self._pool.release(self._connection)
64+
assert self._database._pool is not None, "DatabaseBackend is not running"
65+
await self._database._pool.release(self._connection)
6566
self._connection = None
6667

6768
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:

databases/backends/postgres.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ async def disconnect(self) -> None:
4343
self._pool = None
4444

4545
def connection(self) -> "PostgresConnection":
46-
assert self._pool is not None, "DatabaseBackend is not running"
47-
return PostgresConnection(self._pool, self._dialect)
46+
return PostgresConnection(self, self._dialect)
4847

4948

5049
class Record:
@@ -72,18 +71,20 @@ def __getitem__(self, key: str) -> typing.Any:
7271

7372

7473
class PostgresConnection(ConnectionBackend):
75-
def __init__(self, pool: asyncpg.pool.Pool, dialect: Dialect):
76-
self._pool = pool
74+
def __init__(self, database: PostgresBackend, dialect: Dialect):
75+
self._database = database
7776
self._dialect = dialect
7877
self._connection = None # type: typing.Optional[asyncpg.connection.Connection]
7978

8079
async def acquire(self) -> None:
8180
assert self._connection is None, "Connection is already acquired"
82-
self._connection = await self._pool.acquire()
81+
assert self._database._pool is not None, "DatabaseBackend is not running"
82+
self._connection = await self._database._pool.acquire()
8383

8484
async def release(self) -> None:
8585
assert self._connection is not None, "Connection is not acquired"
86-
self._connection = await self._pool.release(self._connection)
86+
assert self._database._pool is not None, "DatabaseBackend is not running"
87+
self._connection = await self._database._pool.release(self._connection)
8788
self._connection = None
8889

8990
async def fetch_all(self, query: ClauseElement) -> typing.Any:

databases/core.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def __init__(
4444
self._global_connection = None # type: typing.Optional[Connection]
4545
self._global_transaction = None # type: typing.Optional[Transaction]
4646

47+
if self._force_rollback:
48+
self._global_connection = Connection(self._backend)
49+
self._global_transaction = self._global_connection.transaction(
50+
force_rollback=True
51+
)
52+
4753
async def connect(self) -> None:
4854
"""
4955
Establish the connection pool.
@@ -54,10 +60,7 @@ async def connect(self) -> None:
5460
self.is_connected = True
5561

5662
if self._force_rollback:
57-
self._global_connection = Connection(self._backend)
58-
self._global_transaction = self._global_connection.transaction(
59-
force_rollback=True
60-
)
63+
assert self._global_transaction is not None
6164
await self._global_transaction.__aenter__()
6265

6366
async def disconnect(self) -> None:
@@ -69,8 +72,6 @@ async def disconnect(self) -> None:
6972
if self._force_rollback:
7073
assert self._global_transaction is not None
7174
await self._global_transaction.__aexit__()
72-
self._global_transaction = None
73-
self._global_connection = None
7475

7576
await self._backend.disconnect()
7677
self.is_connected = False

0 commit comments

Comments
 (0)