Skip to content

Commit ffec501

Browse files
committed
improve: read write lock and less blocking
1 parent 71673ba commit ffec501

File tree

5 files changed

+109
-22
lines changed

5 files changed

+109
-22
lines changed

poetry.lock

Lines changed: 7 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "shared-lru-cache"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = ""
55
authors = ["Richard Löwenström <[email protected]>"]
66
readme = "README.md"

shared_lru_cache/read_write_lock.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from __future__ import annotations
2+
3+
from multiprocessing import Condition, Lock, Manager, Value
4+
5+
6+
class ReadWriteLock:
    """Readers-writer lock built on multiprocessing.Manager primitives.

    Multiple readers may hold the lock at once; a writer gets exclusive
    access.  The first reader to arrive acquires ``writers_lock`` and the
    last reader to leave releases it, so writers are blocked while any
    reader is active (readers-preference: a steady stream of readers can
    starve writers).

    The original class-level annotations (``readers: Value[int]`` etc.)
    were removed: ``multiprocessing.Lock``/``Condition`` are factory
    functions, not types, and a started Manager hands back proxies anyway.
    """

    def __init__(self, manager):
        """Create the shared state on *manager* so it is visible across processes.

        Args:
            manager: a started ``multiprocessing.Manager`` (or any object
                exposing compatible ``Value``, ``Lock`` and ``Condition``
                factories).
        """
        # Number of readers currently inside a read section.
        self.readers = manager.Value("i", 0)
        # Guards ``readers`` and the first-in / last-out transitions.
        self.readers_lock = manager.Lock()
        # Held by the reader group as a whole, or by a single writer.
        self.writers_lock = manager.Lock()
        # NOTE(review): this condition is created but never waited on or
        # notified anywhere in this file; kept only so existing code that
        # touches the attribute keeps working.
        self.readers_ok = manager.Condition(self.readers_lock)

    def read_lock(self):
        """Return a context manager granting shared (read) access."""
        return ReadLock(self)

    def write_lock(self):
        """Return a context manager granting exclusive (write) access."""
        return WriteLock(self)
23+
24+
25+
class ReadLock:
    """Context-manager half of :class:`ReadWriteLock` granting shared access."""

    def __init__(self, rw_lock):
        # Parent lock whose shared counters/locks we manipulate.
        self.rw_lock = rw_lock

    def __enter__(self):
        """Register this reader; the first reader locks writers out.

        Fix: returns ``self`` so ``with lock.read_lock() as r:`` binds the
        lock object instead of ``None`` (previously nothing was returned).
        """
        with self.rw_lock.readers_lock:
            self.rw_lock.readers.value += 1
            if self.rw_lock.readers.value == 1:
                # First reader in: block writers until the last reader leaves.
                self.rw_lock.writers_lock.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Deregister this reader; the last reader readmits writers."""
        with self.rw_lock.readers_lock:
            self.rw_lock.readers.value -= 1
            if self.rw_lock.readers.value == 0:
                self.rw_lock.writers_lock.release()
        return False  # never suppress exceptions
42+
43+
44+
class WriteLock:
    """Context-manager half of :class:`ReadWriteLock` granting exclusive access."""

    def __init__(self, rw_lock):
        # Parent lock whose writers_lock we take exclusively.
        self.rw_lock = rw_lock

    def __enter__(self):
        """Take exclusive access; blocks until readers and writers are out.

        Fix: returns ``self`` so ``with lock.write_lock() as w:`` binds the
        lock object instead of ``None`` (previously nothing was returned).
        """
        self.rw_lock.writers_lock.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.rw_lock.writers_lock.release()
        return False  # never suppress exceptions

shared_lru_cache/shared_lru_cache.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22
import pickle
33
from multiprocessing import Manager
44

5+
import numpy as np
6+
import torch
7+
8+
from .read_write_lock import ReadWriteLock
9+
510

611
class SharedLRUCache:
712
def __init__(self, maxsize=128):
    """Set up cross-process shared storage for the cache.

    Args:
        maxsize: maximum number of entries kept before eviction
            (the eviction loop trims ``order`` down to this size).
    """
    self.maxsize = maxsize
    # A single Manager owns every shared structure so all worker
    # processes observe the same cache state.
    self.manager = Manager()
    # Readers-writer lock replacing the previous RLock/Lock pair.
    self.lock = ReadWriteLock(self.manager)
    # Key presence map (stores key -> key) for fast membership tests.
    self.cache = self.manager.dict()
    # Key -> serialized payload storage.
    self.data_store = self.manager.dict()
    # Access order used for LRU eviction (least recent first).
    self.order = self.manager.list()
1519

@@ -18,26 +22,26 @@ def __call__(self, func):
1822
def wrapper(*args, **kwargs):
1923
key = str((args, frozenset(kwargs.items())))
2024

21-
with self.read_lock:
25+
with self.lock.read_lock():
2226
if key in self.cache:
23-
with self.write_lock:
24-
if (
25-
key in self.order
26-
): # Check if key is in order before removing
27-
self.order.remove(key)
28-
self.order.append(key)
29-
return pickle.loads(self.data_store[key])
27+
hit = True
28+
serialized_result, obj_info = self.data_store[key]
29+
else:
30+
hit = False
31+
32+
if hit:
33+
return self.deserialize(serialized_result, obj_info)
3034

3135
result = func(*args, **kwargs)
32-
serialized_result = pickle.dumps(result)
36+
serialized_result, obj_info = self.serialize(result)
3337

34-
with self.write_lock:
35-
# Check again in case another process has updated the cache
38+
with self.lock.read_lock():
3639
if key in self.cache:
3740
return result
3841

42+
with self.lock.write_lock():
3943
self.cache[key] = key
40-
self.data_store[key] = serialized_result
44+
self.data_store[key] = (serialized_result, obj_info)
4145
self.order.append(key)
4246

4347
while len(self.order) > self.maxsize:
@@ -52,6 +56,32 @@ def wrapper(*args, **kwargs):
5256
wrapper.data_store = self.data_store
5357
return wrapper
5458

59+
def serialize(self, obj):
    """Serialize *obj* into ``(payload_bytes, obj_info)`` for the data store.

    numpy arrays and torch tensors are stored as raw buffers plus enough
    metadata to rebuild them without pickling; anything else falls back
    to pickle.

    Args:
        obj: the value to cache.

    Returns:
        Tuple ``(bytes, obj_info)`` where ``obj_info[0]`` is one of
        ``"numpy"``, ``"torch"`` or ``"other"``.
    """
    if isinstance(obj, np.ndarray):
        obj_info = ("numpy", obj.shape, obj.dtype.str)
        return obj.tobytes(), obj_info
    elif isinstance(obj, torch.Tensor):
        # Fix: the original called ``obj.byte()`` and discarded the
        # result -- a pure no-op, removed.  detach() handles tensors
        # that require grad and contiguous() handles views, since
        # .numpy() needs a grad-free, contiguous CPU buffer.
        numpy_array = obj.detach().cpu().contiguous().numpy()
        obj_info = ("torch", numpy_array.shape, numpy_array.dtype.str)
        return numpy_array.tobytes(), obj_info
    else:
        obj_info = ("other",)
        return pickle.dumps(obj), obj_info
71+
72+
def deserialize(self, data, obj_info):
    """Rebuild an object previously produced by :meth:`serialize`.

    Args:
        data: raw payload bytes.
        obj_info: metadata tuple; ``obj_info[0]`` selects the decoder.

    Returns:
        A numpy array, a torch tensor, or the unpickled object.
    """
    obj_type, *info = obj_info
    if obj_type == "numpy":
        shape, dtype = info
        # np.frombuffer gives a read-only view over the cached bytes;
        # kept as-is here -- presumably callers treat cached arrays as
        # immutable snapshots (TODO confirm).
        return np.frombuffer(data, dtype=np.dtype(dtype)).reshape(shape)
    elif obj_type == "torch":
        shape, dtype = info
        dtype = np.dtype(dtype) if isinstance(dtype, str) else dtype
        # Fix: copy before wrapping -- torch.from_numpy on the
        # non-writable frombuffer view warns, and any in-place op on
        # the resulting tensor would be undefined behaviour.
        numpy_array = np.frombuffer(data, dtype=dtype).reshape(shape).copy()
        return torch.from_numpy(numpy_array)
    else:
        # NOTE(review): pickle.loads is unsafe on untrusted bytes; here
        # the payload comes from this process group's own data store.
        return pickle.loads(data)
84+
5585

5686
def shared_lru_cache(maxsize=128):
5787
return SharedLRUCache(maxsize)

tests/test_pytorch_data_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
from shared_lru_cache import shared_lru_cache
99

10-
MAX_INDEX = 20
10+
MAX_INDEX = 10
1111

1212

1313
def load_image(idx):
    """Stub loader for the cache tests: stall, then return a large blank
    uint8 tensor (~75 MB) standing in for a decoded image."""
    time.sleep(5)  # simulate slow data loading
    return torch.zeros(30000, 2500, dtype=torch.uint8)
1616

1717

1818
lru_cached_load_image = lru_cache(maxsize=128)(load_image)

0 commit comments

Comments
 (0)