Skip to content

Commit eec91fb

Browse files
committed
tests: lock
1 parent 1413954 commit eec91fb

File tree

2 files changed

+149
-3
lines changed

2 files changed

+149
-3
lines changed

src/tests/test_sync.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from __future__ import annotations
2+
3+
import threading
4+
import time
5+
6+
import anyio
7+
import pytest
8+
9+
import typed_diskcache
10+
11+
pytestmark = pytest.mark.anyio
12+
13+
14+
def test_lock(cache):
15+
state = {"num": 0}
16+
lock = typed_diskcache.SyncLock(cache, "demo")
17+
18+
def worker() -> None:
19+
state["num"] += 1
20+
with lock:
21+
assert lock.locked
22+
state["num"] += 1
23+
time.sleep(0.1)
24+
25+
with lock:
26+
thread = threading.Thread(target=worker)
27+
thread.start()
28+
time.sleep(0.1)
29+
assert state["num"] == 1
30+
thread.join()
31+
assert state["num"] == 2
32+
33+
34+
def test_rlock(cache):
35+
state = {"num": 0}
36+
rlock = typed_diskcache.SyncRLock(cache, "demo")
37+
38+
def worker() -> None:
39+
state["num"] += 1
40+
with rlock:
41+
with rlock:
42+
state["num"] += 1
43+
time.sleep(0.1)
44+
45+
with rlock:
46+
thread = threading.Thread(target=worker)
47+
thread.start()
48+
time.sleep(0.1)
49+
assert state["num"] == 1
50+
thread.join()
51+
assert state["num"] == 2
52+
53+
54+
def test_semaphore(cache):
55+
state = {"num": 0}
56+
semaphore = typed_diskcache.SyncSemaphore(cache, "demo", value=3)
57+
58+
def worker() -> None:
59+
state["num"] += 1
60+
with semaphore:
61+
state["num"] += 1
62+
time.sleep(0.1)
63+
64+
semaphore.acquire()
65+
semaphore.acquire()
66+
with semaphore:
67+
thread = threading.Thread(target=worker)
68+
thread.start()
69+
time.sleep(0.1)
70+
assert state["num"] == 1
71+
thread.join()
72+
assert state["num"] == 2
73+
semaphore.release()
74+
semaphore.release()
75+
76+
77+
async def test_async_lock(cache):
78+
state = {"num": 0}
79+
lock = typed_diskcache.AsyncLock(cache, "demo")
80+
81+
async def worker() -> None:
82+
state["num"] += 1
83+
async with lock:
84+
assert lock.locked
85+
state["num"] += 1
86+
await anyio.sleep(0.1)
87+
88+
async with lock:
89+
thread = threading.Thread(target=anyio.run, args=(worker,))
90+
thread.start()
91+
await anyio.sleep(0.1)
92+
assert state["num"] == 1
93+
thread.join()
94+
assert state["num"] == 2
95+
96+
97+
@pytest.mark.only
98+
async def test_async_rlock(cache):
99+
state = {"num": 0}
100+
rlock = typed_diskcache.AsyncRLock(cache, "demo")
101+
102+
async def worker() -> None:
103+
state["num"] += 1
104+
async with rlock:
105+
async with rlock:
106+
state["num"] += 1
107+
await anyio.sleep(0.1)
108+
109+
async with rlock:
110+
thread = threading.Thread(target=anyio.run, args=(worker,))
111+
thread.start()
112+
await anyio.sleep(0.1)
113+
assert state["num"] == 1
114+
thread.join()
115+
assert state["num"] == 2
116+
117+
118+
async def test_async_semaphore(cache):
119+
state = {"num": 0}
120+
semaphore = typed_diskcache.AsyncSemaphore(cache, "demo", value=3)
121+
122+
async def worker() -> None:
123+
state["num"] += 1
124+
async with semaphore:
125+
state["num"] += 1
126+
await anyio.sleep(0.1)
127+
128+
await semaphore.acquire()
129+
await semaphore.acquire()
130+
async with semaphore:
131+
thread = threading.Thread(target=anyio.run, args=(worker,))
132+
thread.start()
133+
await anyio.sleep(0.1)
134+
assert state["num"] == 1
135+
thread.join()
136+
assert state["num"] == 2
137+
await semaphore.release()
138+
await semaphore.release()

src/typed_diskcache/implement/sync/lock.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def acquire(self) -> None:
108108
self.key, None, expire=self.expire, tags=self.tags, retry=True
109109
)
110110
if added:
111-
break
111+
return
112112
time.sleep(SPIN_LOCK_SLEEP)
113113
timeout = time.monotonic() - start
114114

@@ -179,7 +179,11 @@ def acquire(self) -> None:
179179
self._cache.get, self.key, default=("default", 0)
180180
)
181181
container_value = validate_lock_value(container.value)
182-
if container.default or pid_tid == container_value[0]:
182+
if (
183+
container.default
184+
or pid_tid == container_value[0]
185+
or container_value[1] <= 0
186+
):
183187
value = 1 if container.default else container_value[1] + 1
184188
logger.debug("acquired lock: %s, value: %d", pid_tid, value)
185189
context.run(
@@ -393,7 +397,11 @@ async def acquire(self) -> None:
393397
self._cache.aget, self.key, default=("default", 0)
394398
)
395399
container_value = validate_lock_value(container.value)
396-
if container.default or pid_tid == container_value[0]:
400+
if (
401+
container.default
402+
or pid_tid == container_value[0]
403+
or container_value[1] <= 0
404+
):
397405
value = 1 if container.default else container_value[1] + 1
398406
logger.debug("acquired lock: %s, value: %d", pid_tid, value)
399407
await context.run(

0 commit comments

Comments
 (0)