Skip to content

Commit 8adeea8

Browse files
committed
Store acquired connection in context variable
1 parent d122a7a commit 8adeea8

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

mautrix/util/async_db/aiosqlite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async def stop(self) -> None:
181181
self._conns -= 1
182182
await conn.close()
183183

184-
def acquire(self) -> AsyncContextManager[LoggingConnection]:
184+
def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
185185
if self._parent:
186186
return self._parent.acquire()
187187
return self._acquire()

mautrix/util/async_db/asyncpg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def _handle_exception(self, err: Exception) -> None:
9696
sys.exit(26)
9797

9898
@asynccontextmanager
99-
async def acquire(self) -> LoggingConnection:
99+
async def acquire_direct(self) -> LoggingConnection:
100100
async with self.pool.acquire() as conn:
101101
yield LoggingConnection(
102102
self.scheme, conn, self.log, handle_exception=self._handle_exception

mautrix/util/async_db/database.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from typing import Any, AsyncContextManager, Type
99
from abc import ABC, abstractmethod
10+
from contextlib import asynccontextmanager
11+
from contextvars import ContextVar
1012
import logging
1113

1214
from yarl import URL
@@ -23,6 +25,8 @@
2325
from aiosqlite import Cursor
2426
from asyncpg import Record
2527

28+
conn_var: ContextVar[LoggingConnection | None] = ContextVar("db_connection", default=None)
29+
2630

2731
class Database(ABC):
2832
schemes: dict[str, Type[Database]] = {}
@@ -128,9 +132,22 @@ async def stop(self) -> None:
128132
pass
129133

130134
@abstractmethod
131-
def acquire(self) -> AsyncContextManager[LoggingConnection]:
135+
def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
132136
pass
133137

138+
@asynccontextmanager
139+
async def acquire(self) -> LoggingConnection:
140+
conn = conn_var.get(None)
141+
if conn is not None:
142+
yield conn
143+
return
144+
async with self.acquire_direct() as conn:
145+
token = conn_var.set(conn)
146+
try:
147+
yield conn
148+
finally:
149+
conn_var.reset(token)
150+
134151
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str | Cursor:
135152
async with self.acquire() as conn:
136153
return await conn.execute(query, *args, timeout=timeout)

0 commit comments

Comments
 (0)