Skip to content

Commit 1190aac

Browse files
committed
feat: enter_connection
1 parent 79a5775 commit 1190aac

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

src/typed_diskcache/core/context.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import threading
55
from contextlib import contextmanager
6-
from contextvars import ContextVar, Token, copy_context
6+
from contextvars import Context, ContextVar, Token, copy_context
77
from functools import partial, wraps
88
from typing import TYPE_CHECKING, Any, overload
99

@@ -14,13 +14,26 @@
1414
if TYPE_CHECKING:
1515
from collections.abc import Callable, Generator
1616

17-
__all__ = ["log_context", "context"]
17+
from sqlalchemy.engine import Connection
18+
from sqlalchemy.ext.asyncio import AsyncConnection
19+
20+
__all__ = [
21+
"log_context",
22+
"conn_context",
23+
"aconn_context",
24+
"enter_connection",
25+
"context",
26+
]
1827

1928
_F = TypeVar("_F", bound="Callable[..., Any]", infer_variance=True)
2029

2130
log_context: ContextVar[tuple[str, int]] = ContextVar(
2231
"log_thread_context", default=(DEFAULT_LOG_CONTEXT, DEFAULT_LOG_THREAD)
2332
)
33+
conn_context: ContextVar[Connection | None] = ContextVar("conn_context", default=None)
34+
aconn_context: ContextVar[AsyncConnection | None] = ContextVar(
35+
"aconn_context", default=None
36+
)
2437

2538

2639
@contextmanager
@@ -59,6 +72,33 @@ def context(func_or_context: _F | str) -> _F | Callable[[_F], _F]:
5972
return partial(_context, name=func_or_context)
6073

6174

75+
@contextmanager
76+
def enter_connection(
77+
conn: Connection | AsyncConnection,
78+
) -> Generator[Context, None, None]:
79+
"""Enter the connection context.
80+
81+
Args:
82+
conn: The connection to enter.
83+
84+
Yields:
85+
Copy of the current context.
86+
"""
87+
from sqlalchemy.ext.asyncio import AsyncConnection
88+
89+
if isinstance(conn, AsyncConnection):
90+
token = aconn_context.set(conn)
91+
reset_context = aconn_context.reset
92+
else:
93+
token = conn_context.set(conn)
94+
reset_context = conn_context.reset
95+
context = copy_context()
96+
try:
97+
yield context
98+
finally:
99+
reset_context(token) # pyright: ignore[reportArgumentType]
100+
101+
62102
def _context(func: _F, *, name: str) -> _F:
63103
if inspect.iscoroutinefunction(func):
64104

src/typed_diskcache/database/connection.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from pathlib import Path
66
from typing import TYPE_CHECKING, Any
77

8+
import anyio.lowlevel
89
import sqlalchemy as sa
910

1011
from typed_diskcache import exception as te
12+
from typed_diskcache.core.context import aconn_context, conn_context
1113
from typed_diskcache.core.types import EvictionPolicy
1214
from typed_diskcache.database import connect as db_connect
1315
from typed_diskcache.database.model import Cache
@@ -127,12 +129,23 @@ def _async_engine(self) -> AsyncEngine:
127129
@contextmanager
128130
def connect(self) -> Generator[SAConnection, None, None]:
129131
"""Connect to the database."""
132+
conn = conn_context.get()
133+
if conn is not None:
134+
yield conn
135+
return
136+
130137
with self._sync_engine.connect() as connection:
131138
yield connection
132139

133140
@asynccontextmanager
134141
async def aconnect(self) -> AsyncGenerator[AsyncConnection, None]:
135142
"""Connect to the database."""
143+
conn = aconn_context.get()
144+
if conn is not None:
145+
await anyio.lowlevel.checkpoint()
146+
yield conn
147+
return
148+
136149
async with self._async_engine.connect() as connection:
137150
yield connection
138151

src/typed_diskcache/implement/cache/default/main.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typed_diskcache import exception as te
1616
from typed_diskcache.core.const import ENOVAL
17-
from typed_diskcache.core.context import context
17+
from typed_diskcache.core.context import context, enter_connection
1818
from typed_diskcache.core.types import (
1919
Container,
2020
EvictionPolicy,
@@ -736,9 +736,13 @@ def stats(self, *, enable: bool = True, reset: bool = False) -> Stats:
736736
.where(SettingsTable.key == SettingsKey.STATISTICS)
737737
.values(value=enable)
738738
)
739-
self.update_settings(self.settings.model_copy(update={"statistics": enable}))
739+
with enter_connection(sa_conn) as context:
740+
context.run(
741+
self.update_settings,
742+
self.settings.model_copy(update={"statistics": enable}),
743+
)
740744

741-
return stats
745+
return stats
742746

743747
@context
744748
@override
@@ -768,10 +772,12 @@ async def astats(self, *, enable: bool = True, reset: bool = False) -> Stats:
768772
.where(SettingsTable.key == SettingsKey.STATISTICS)
769773
.values(value=enable)
770774
)
771-
await self.aupdate_settings(
772-
self.settings.model_copy(update={"statistics": enable})
773-
)
774-
return stats
775+
with enter_connection(sa_conn) as context:
776+
await context.run(
777+
self.aupdate_settings,
778+
self.settings.model_copy(update={"statistics": enable}),
779+
)
780+
return stats
775781

776782
@override
777783
def close(self) -> None:

0 commit comments

Comments
 (0)