Skip to content

Commit 2cf6677

Browse files
committed
fix: sync connections
1 parent 1190aac commit 2cf6677

File tree

2 files changed

+50
-19
lines changed

2 files changed

+50
-19
lines changed

src/typed_diskcache/implement/sync/lock.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from typed_diskcache import exception as te
1313
from typed_diskcache.core.const import DEFAULT_LOCK_TIMEOUT, SPIN_LOCK_SLEEP
14-
from typed_diskcache.core.context import context
14+
from typed_diskcache.core.context import context, enter_connection
1515
from typed_diskcache.database.connect import transact
1616
from typed_diskcache.interface.sync import AsyncLockProtocol, SyncLockProtocol
1717
from typed_diskcache.log import get_logger
@@ -124,13 +124,20 @@ def acquire(self) -> None:
124124
while timeout < self.timeout:
125125
sa_conn = stack.enter_context(self._cache.conn.connect())
126126
stack.enter_context(transact(sa_conn))
127-
container = self._cache.get(self.key, default=("default", 0))
127+
context = stack.enter_context(enter_connection(sa_conn))
128+
container = context.run(
129+
self._cache.get, self.key, default=("default", 0)
130+
)
128131
container_value = validate_lock_value(container.value)
129132
if container.default or pid_tid == container_value[0]:
130133
value = 1 if container.default else container_value[1] + 1
131134
logger.debug("acquired lock: %s, value: %d", pid_tid, value)
132-
self._cache.set(
133-
self.key, (pid_tid, value), expire=self.expire, tags=self.tags
135+
context.run(
136+
self._cache.set,
137+
self.key,
138+
(pid_tid, value),
139+
expire=self.expire,
140+
tags=self.tags,
134141
)
135142
return
136143
stack.close()
@@ -149,7 +156,8 @@ def release(self) -> None:
149156
with ExitStack() as stack:
150157
sa_conn = stack.enter_context(self._cache.conn.connect())
151158
stack.enter_context(transact(sa_conn))
152-
container = self._cache.get(self.key, default=("default", 0))
159+
context = stack.enter_context(enter_connection(sa_conn))
160+
container = context.run(self._cache.get, self.key, default=("default", 0))
153161
container_value = validate_lock_value(container.value)
154162
if (
155163
container.default
@@ -163,7 +171,8 @@ def release(self) -> None:
163171
container_value,
164172
)
165173
raise te.TypedDiskcacheRuntimeError("cannot release un-acquired lock")
166-
self._cache.set(
174+
context.run(
175+
self._cache.set,
167176
self.key,
168177
(container_value[0], container_value[1] - 1),
169178
expire=self.expire,
@@ -273,12 +282,16 @@ async def acquire(self) -> None:
273282
self._cache.conn.aconnect()
274283
)
275284
await sub_stack.enter_async_context(transact(sa_conn))
276-
container = await self._cache.aget(self.key, default=("default", 0))
285+
context = stack.enter_context(enter_connection(sa_conn))
286+
container = await context.run(
287+
self._cache.aget, self.key, default=("default", 0)
288+
)
277289
container_value = validate_lock_value(container.value)
278290
if container.default or pid_tid == container_value[0]:
279291
value = 1 if container.default else container_value[1] + 1
280292
logger.debug("acquired lock: %s, value: %d", pid_tid, value)
281-
await self._cache.aset(
293+
await context.run(
294+
self._cache.aset,
282295
self.key,
283296
(pid_tid, value),
284297
expire=self.expire,

src/typed_diskcache/implement/sync/semaphore.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from typed_diskcache import exception as te
1111
from typed_diskcache.core.const import DEFAULT_LOCK_TIMEOUT, SPIN_LOCK_SLEEP
12-
from typed_diskcache.core.context import context
12+
from typed_diskcache.core.context import context, enter_connection
1313
from typed_diskcache.database.connect import transact
1414
from typed_diskcache.interface.sync import AsyncSemaphoreProtocol, SyncSemaphoreProtocol
1515
from typed_diskcache.log import get_logger
@@ -82,10 +82,12 @@ def acquire(self) -> None:
8282
while timeout < self.timeout:
8383
sa_conn = stack.enter_context(self._cache.conn.connect())
8484
stack.enter_context(transact(sa_conn))
85-
container = self._cache.get(self.key, default=self._value)
85+
context = stack.enter_context(enter_connection(sa_conn))
86+
container = context.run(self._cache.get, self.key, default=self._value)
8687
container_value = validate_semaphore_value(container.value)
8788
if container_value > 0:
88-
self._cache.set(
89+
context.run(
90+
self._cache.set,
8991
self.key,
9092
container_value - 1,
9193
expire=self.expire,
@@ -104,7 +106,8 @@ def release(self) -> None:
104106
with ExitStack() as stack:
105107
sa_conn = stack.enter_context(self._cache.conn.connect())
106108
stack.enter_context(transact(sa_conn))
107-
container = self._cache.get(self.key, default=self._value)
109+
context = stack.enter_context(enter_connection(sa_conn))
110+
container = context.run(self._cache.get, self.key, default=self._value)
108111
container_value = validate_semaphore_value(container.value)
109112
if self._value <= container_value:
110113
logger.error(
@@ -115,8 +118,12 @@ def release(self) -> None:
115118
raise te.TypedDiskcacheRuntimeError(
116119
"cannot release un-acquired semaphore"
117120
)
118-
self._cache.set(
119-
self.key, container_value + 1, expire=self.expire, tags=self.tags
121+
context.run(
122+
self._cache.set,
123+
self.key,
124+
container_value + 1,
125+
expire=self.expire,
126+
tags=self.tags,
120127
)
121128

122129
@override
@@ -194,10 +201,14 @@ async def acquire(self) -> None:
194201
self._cache.conn.aconnect()
195202
)
196203
await sub_stack.enter_async_context(transact(sa_conn))
197-
container = await self._cache.aget(self.key, default=self._value)
204+
context = stack.enter_context(enter_connection(sa_conn))
205+
container = await context.run(
206+
self._cache.aget, self.key, default=self._value
207+
)
198208
container_value = validate_semaphore_value(container.value)
199209
if container_value > 0:
200-
await self._cache.aset(
210+
await context.run(
211+
self._cache.aset,
201212
self.key,
202213
container_value - 1,
203214
expire=self.expire,
@@ -215,7 +226,10 @@ async def release(self) -> None:
215226
async with AsyncExitStack() as stack:
216227
sa_conn = await stack.enter_async_context(self._cache.conn.aconnect())
217228
await stack.enter_async_context(transact(sa_conn))
218-
container = await self._cache.aget(self.key, default=self._value)
229+
context = stack.enter_context(enter_connection(sa_conn))
230+
container = await context.run(
231+
self._cache.aget, self.key, default=self._value
232+
)
219233
container_value = validate_semaphore_value(container.value)
220234
if self._value <= container_value:
221235
logger.error(
@@ -226,8 +240,12 @@ async def release(self) -> None:
226240
raise te.TypedDiskcacheRuntimeError(
227241
"cannot release un-acquired semaphore"
228242
)
229-
await self._cache.aset(
230-
self.key, container_value + 1, expire=self.expire, tags=self.tags
243+
await context.run(
244+
self._cache.aset,
245+
self.key,
246+
container_value + 1,
247+
expire=self.expire,
248+
tags=self.tags,
231249
)
232250

233251
@override

0 commit comments

Comments
 (0)