Skip to content

Commit 472ad21

Browse files
committed
fix: sync
1 parent eec91fb commit 472ad21

File tree

7 files changed

+130
-57
lines changed

7 files changed

+130
-57
lines changed

src/tests/test_sync.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ async def worker() -> None:
9494
assert state["num"] == 2
9595

9696

97-
@pytest.mark.only
9897
async def test_async_rlock(cache):
9998
state = {"num": 0}
10099
rlock = typed_diskcache.AsyncRLock(cache, "demo")

src/typed_diskcache/database/connect.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
)
1010
from typing import TYPE_CHECKING, Any, Callable, Protocol, overload, runtime_checkable
1111

12-
import anyio
1312
import sqlalchemy as sa
1413
from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
1514
from sqlalchemy.engine import Connection, Engine, create_engine
@@ -53,7 +52,6 @@
5352

5453
_TIMEOUT = 10
5554
_TIMEOUT_MS = _TIMEOUT * 1000
56-
_LOCK = anyio.Lock()
5755

5856
logger = get_logger()
5957

@@ -243,6 +241,7 @@ def ensure_sqlite_async_engine(
243241
def sync_transact(conn: SyncConnT) -> Generator[SyncConnT, None, None]:
244242
is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False)
245243
if is_begin is False:
244+
logger.debug("enter transaction, session: `%d`", id(conn))
246245
conn.execute(sa.text("BEGIN IMMEDIATE;"))
247246
conn.info[CONNECTION_BEGIN_INFO_KEY] = True
248247

@@ -252,24 +251,26 @@ def sync_transact(conn: SyncConnT) -> Generator[SyncConnT, None, None]:
252251
conn.rollback()
253252
raise
254253
finally:
254+
logger.debug("exit transaction, session: `%d`", id(conn))
255255
with suppress(ResourceClosedError):
256256
conn.info[CONNECTION_BEGIN_INFO_KEY] = False
257257

258258

259259
@asynccontextmanager
260260
async def async_transact(conn: AsyncConnT) -> AsyncGenerator[AsyncConnT, None]:
261-
async with _LOCK:
262-
is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False)
263-
if is_begin is False:
264-
await conn.execute(sa.text("BEGIN IMMEDIATE;"))
265-
conn.info[CONNECTION_BEGIN_INFO_KEY] = True
261+
is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False)
262+
if is_begin is False:
263+
logger.debug("enter transaction, session: `%d`", id(conn))
264+
await conn.execute(sa.text("BEGIN IMMEDIATE;"))
265+
conn.info[CONNECTION_BEGIN_INFO_KEY] = True
266266

267267
try:
268268
yield conn
269269
except Exception:
270270
await conn.rollback()
271271
raise
272272
finally:
273+
logger.debug("exit transaction, session: `%d`", id(conn))
273274
with suppress(ResourceClosedError):
274275
conn.info[CONNECTION_BEGIN_INFO_KEY] = False
275276

src/typed_diskcache/database/connection.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typed_diskcache.core.types import EvictionPolicy
1717
from typed_diskcache.database import connect as db_connect
1818
from typed_diskcache.database.model import Cache
19+
from typed_diskcache.log import get_logger
1920

2021
if TYPE_CHECKING:
2122
from collections.abc import AsyncGenerator, Generator, Mapping
@@ -28,6 +29,8 @@
2829

2930
__all__ = ["Connection"]
3031

32+
logger = get_logger()
33+
3134

3235
class Connection:
3336
"""Database connection."""
@@ -112,40 +115,72 @@ def _async_engine(self) -> AsyncEngine:
112115
return db_connect.set_listeners(engine, self._settings.sqlite_settings)
113116

114117
@contextmanager
115-
def _connect(self) -> Generator[SAConnection, None, None]:
118+
def _connect(self, *, stacklevel: int = 1) -> Generator[SAConnection, None, None]:
116119
with self._sync_engine.connect() as connection:
120+
logger.debug(
121+
"Creating connection: `%d`", id(connection), stacklevel=stacklevel
122+
)
117123
yield connection
124+
logger.debug(
125+
"Closing connection: `%d`", id(connection), stacklevel=stacklevel
126+
)
118127

119128
@contextmanager
120-
def session(self) -> Generator[Session, None, None]:
129+
def session(self, *, stacklevel: int = 1) -> Generator[Session, None, None]:
121130
"""Connect to the database."""
122131
session = self._context.get()
123132
if session is not None:
133+
logger.debug("Reusing session: `%d`", id(session), stacklevel=stacklevel)
124134
yield session
125135
return
126136

127-
with self._connect() as connection:
137+
with self._connect(stacklevel=stacklevel + 2) as connection:
128138
with Session(connection, autoflush=False) as session:
139+
logger.debug(
140+
"Creating session: `%d`", id(session), stacklevel=stacklevel
141+
)
129142
yield session
143+
logger.debug(
144+
"Closing session: `%d`", id(session), stacklevel=stacklevel
145+
)
130146

131147
@asynccontextmanager
132-
async def _aconnect(self) -> AsyncGenerator[AsyncConnection, None]:
148+
async def _aconnect(
149+
self, *, stacklevel: int = 1
150+
) -> AsyncGenerator[AsyncConnection, None]:
133151
"""Connect to the database."""
134152
async with self._async_engine.connect() as connection:
153+
logger.debug(
154+
"Creating async connection: `%d`", id(connection), stacklevel=stacklevel
155+
)
135156
yield connection
157+
logger.debug(
158+
"Closing async connection: `%d`", id(connection), stacklevel=stacklevel
159+
)
136160

137161
@asynccontextmanager
138-
async def asession(self) -> AsyncGenerator[AsyncSession, None]:
162+
async def asession(
163+
self, *, stacklevel: int = 1
164+
) -> AsyncGenerator[AsyncSession, None]:
139165
"""Connect to the database."""
140166
session = self._acontext.get()
141167
if session is not None:
168+
logger.debug(
169+
"Reusing async session: `%d`", id(session), stacklevel=stacklevel
170+
)
142171
await anyio.lowlevel.checkpoint()
143172
yield session
144173
return
145174

146-
async with self._aconnect() as connection:
175+
async with self._aconnect(stacklevel=stacklevel + 2) as connection:
147176
async with AsyncSession(connection, autoflush=False) as session:
177+
logger.debug(
178+
"Creating async session: `%d`", id(session), stacklevel=stacklevel
179+
)
148180
yield session
181+
logger.debug(
182+
"Closing async session: `%d`", id(session), stacklevel=stacklevel
183+
)
149184

150185
def close(self) -> None:
151186
"""Close the connection."""
@@ -195,6 +230,11 @@ def enter_session(
195230
context_var = (
196231
self._acontext if isinstance(session, AsyncSession) else self._context
197232
)
233+
logger.debug(
234+
"Entering session context: `%s`, session: `%d`",
235+
context_var.name,
236+
id(session),
237+
)
198238
with enter_session(session, context_var) as context: # pyright: ignore[reportArgumentType]
199239
yield context
200240

src/typed_diskcache/implement/cache/default/main.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
@context("Cache.length")
109109
@override
110110
def __len__(self) -> int:
111-
with self.conn.session() as session:
111+
with self.conn.session(stacklevel=4) as session:
112112
return session.scalars(
113113
sa.select(Metadata.value).where(Metadata.key == MetadataKey.COUNT)
114114
).one()
@@ -128,7 +128,7 @@ def __getitem__(self, key: Any) -> Container[Any]:
128128
@override
129129
def __contains__(self, key: Any) -> bool:
130130
db_key, raw = self.disk.put(key)
131-
with self.conn.session() as session:
131+
with self.conn.session(stacklevel=4) as session:
132132
row = session.scalars(
133133
sa.select(CacheTable.id).where(
134134
CacheTable.key == db_key,
@@ -241,7 +241,7 @@ def get(
241241
and self.settings.eviction_policy == EvictionPolicy.NONE
242242
):
243243
logger.debug("Cache statistics disabled or eviction policy is NONE")
244-
with self.conn.session() as session:
244+
with self.conn.session(stacklevel=4) as session:
245245
row = session.scalars(
246246
select_stmt, {"expire_time": time.time()}
247247
).one_or_none()
@@ -328,7 +328,7 @@ async def aget(
328328
and self.settings.eviction_policy == EvictionPolicy.NONE
329329
):
330330
logger.debug("Cache statistics disabled or eviction policy is NONE")
331-
async with self.conn.asession() as session:
331+
async with self.conn.asession(stacklevel=4) as session:
332332
row_fetch = await session.scalars(
333333
select_stmt, {"expire_time": time.time()}
334334
)
@@ -685,7 +685,7 @@ async def _async_cull(
685685
@context
686686
@override
687687
def volume(self) -> int:
688-
with self.conn.session() as session:
688+
with self.conn.session(stacklevel=4) as session:
689689
page_count: int = session.execute(
690690
sa.text("PRAGMA page_count;")
691691
).scalar_one()
@@ -698,7 +698,7 @@ def volume(self) -> int:
698698
@context
699699
@override
700700
async def avolume(self) -> int:
701-
async with self.conn.asession() as session:
701+
async with self.conn.asession(stacklevel=4) as session:
702702
page_count_fetch = await session.execute(sa.text("PRAGMA page_count;"))
703703
page_count: int = page_count_fetch.scalar_one()
704704
size_fetch = await session.scalars(
@@ -1068,7 +1068,7 @@ def filter(
10681068
lower_bound = 0
10691069
tags_count = len(tags)
10701070
while True:
1071-
with self.conn.session() as session:
1071+
with self.conn.session(stacklevel=4) as session:
10721072
rows = session.execute(
10731073
stmt,
10741074
{
@@ -1105,7 +1105,7 @@ async def afilter(
11051105
lower_bound = 0
11061106
tags_count = len(tags)
11071107
while True:
1108-
async with self.conn.asession() as session:
1108+
async with self.conn.asession(stacklevel=4) as session:
11091109
rows_fetch = await session.execute(
11101110
stmt,
11111111
{
@@ -1941,7 +1941,7 @@ async def acheck(
19411941
@override
19421942
def iterkeys(self, *, reverse: bool = False) -> Generator[Any, None, None]:
19431943
select_stmt, iter_stmt = default_utils.prepare_iterkeys_stmt(reverse=reverse)
1944-
with self.conn.session() as session:
1944+
with self.conn.session(stacklevel=4) as session:
19451945
row = session.execute(select_stmt).one_or_none()
19461946

19471947
if not row:
@@ -1965,7 +1965,7 @@ def iterkeys(self, *, reverse: bool = False) -> Generator[Any, None, None]:
19651965
@override
19661966
async def aiterkeys(self, *, reverse: bool = False) -> AsyncGenerator[Any, None]:
19671967
select_stmt, iter_stmt = default_utils.prepare_iterkeys_stmt(reverse=reverse)
1968-
async with self.conn.asession() as session:
1968+
async with self.conn.asession(stacklevel=4) as session:
19691969
row_fetch = await session.execute(select_stmt)
19701970
row = row_fetch.one_or_none()
19711971

src/typed_diskcache/implement/cache/default/utils.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,17 @@ def prepare_cull_stmt(
152152
return filenames_select_stmt, filenames_delete_stmt, select_stmt
153153

154154

155-
def transact_process(
155+
def transact_process( # noqa: PLR0913
156156
stack: ExitStack,
157157
conn: Connection,
158158
disk: DiskProtocol,
159159
*,
160160
retry: bool = False,
161161
filename: str | PathLike[str] | None = None,
162+
stacklevel: int = 3,
162163
) -> Session | None:
163164
try:
164-
session = stack.enter_context(conn.session())
165+
session = stack.enter_context(conn.session(stacklevel=stacklevel))
165166
session = stack.enter_context(database_transact(session))
166167
except OperationalError as exc:
167168
stack.close()
@@ -174,16 +175,17 @@ def transact_process(
174175
return session
175176

176177

177-
async def async_transact_process(
178+
async def async_transact_process( # noqa: PLR0913
178179
stack: AsyncExitStack,
179180
conn: Connection,
180181
disk: DiskProtocol,
181182
*,
182183
retry: bool = False,
183184
filename: str | PathLike[str] | None = None,
185+
stacklevel: int = 3,
184186
) -> AsyncSession | None:
185187
try:
186-
session = await stack.enter_async_context(conn.asession())
188+
session = await stack.enter_async_context(conn.asession(stacklevel=stacklevel))
187189
session = await stack.enter_async_context(database_transact(session))
188190
except OperationalError as exc:
189191
await stack.aclose()
@@ -218,7 +220,7 @@ def iter_disk(
218220
)
219221

220222
while True:
221-
with conn.session() as session:
223+
with conn.session(stacklevel=4) as session:
222224
rows = session.execute(
223225
stmt,
224226
{"left_bound": rowid, "right_bound": bound}
@@ -254,7 +256,7 @@ async def aiter_disk(
254256
)
255257

256258
while True:
257-
async with conn.asession() as session:
259+
async with conn.asession(stacklevel=4) as session:
258260
rows_fetch = await session.execute(
259261
stmt, {"left_bound": rowid, "right_bound": bound}
260262
)
@@ -422,7 +424,12 @@ def transact(
422424
while session is None:
423425
stack.close()
424426
session = transact_process(
425-
stack, conn, disk, retry=retry, filename=filename
427+
stack,
428+
conn,
429+
disk,
430+
retry=retry,
431+
filename=filename,
432+
stacklevel=stacklevel + 4,
426433
)
427434

428435
logger.debug("Enter transaction `%s`", filename, stacklevel=stacklevel)
@@ -461,12 +468,20 @@ async def async_transact(
461468
while session is None:
462469
await stack.aclose()
463470
session = await async_transact_process(
464-
stack, conn, disk, retry=retry, filename=filename
471+
stack,
472+
conn,
473+
disk,
474+
retry=retry,
475+
filename=filename,
476+
stacklevel=stacklevel + 4,
465477
)
466478

467479
logger.debug("Enter async transaction `%s`", filename, stacklevel=stacklevel)
468480
stack.callback(
469-
logger.debug, "Exit async transaction `%s`", filename, stacklevel=stacklevel
481+
logger.debug,
482+
"Exit async transaction `%s`",
483+
filename,
484+
stacklevel=stacklevel + 2,
470485
)
471486
try:
472487
stack.enter_context(receive)
@@ -599,12 +614,12 @@ def prepare_filter_stmt(
599614

600615

601616
def find_max_id(conn: Connection) -> int | None:
602-
with conn.session() as session:
617+
with conn.session(stacklevel=4) as session:
603618
return session.scalar(sa.select(sa.func.max(CacheTable.id)))
604619

605620

606621
async def async_find_max_id(conn: Connection) -> int | None:
607-
async with conn.asession() as session:
622+
async with conn.asession(stacklevel=4) as session:
608623
return await session.scalar(sa.select(sa.func.max(CacheTable.id)))
609624

610625

@@ -976,7 +991,7 @@ def prepare_iterkeys_stmt(
976991

977992

978993
async def acheck_integrity(*, conn: Connection, fix: bool, stacklevel: int = 2) -> None:
979-
async with conn.asession() as session:
994+
async with conn.asession(stacklevel=4) as session:
980995
integrity_fetch = await session.execute(sa.text("PRAGMA integrity_check;"))
981996
integrity = integrity_fetch.scalars().all()
982997

@@ -1168,7 +1183,7 @@ async def acheck_metadata_size(
11681183

11691184

11701185
def check_integrity(*, conn: Connection, fix: bool, stacklevel: int = 2) -> None:
1171-
with conn.session() as session:
1186+
with conn.session(stacklevel=4) as session:
11721187
integrity = session.execute(sa.text("PRAGMA integrity_check;")).scalars().all()
11731188

11741189
if len(integrity) != 1 or integrity[0] != "ok":

0 commit comments

Comments
 (0)