Skip to content

Commit 039030f

Browse files
committed
V5.0.1
1 parent 7a21c41 commit 039030f

File tree

1 file changed

+103
-82
lines changed

1 file changed

+103
-82
lines changed

tests/thread_safety/test_thread_safety_patterns.py

Lines changed: 103 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ class TestAdvancedThreadSafetyPatterns:
1818
def test_producer_consumer_pattern(self):
1919
"""Test thread safety in producer-consumer pattern."""
2020

21-
@auto_thread_safe(['put', 'get'])
2221
class ThreadSafeQueue:
2322
def __init__(self, maxsize=10):
2423
self.queue = []
@@ -28,20 +27,25 @@ def __init__(self, maxsize=10):
2827
def put(self, item):
2928
with self.condition:
3029
while len(self.queue) >= self.maxsize:
31-
self.condition.wait()
30+
self.condition.wait(timeout=1.0) # Add timeout to prevent infinite wait
31+
if len(self.queue) >= self.maxsize:
32+
raise TimeoutError("Queue full timeout")
3233
self.queue.append(item)
3334
self.condition.notify_all()
3435

3536
def get(self):
3637
with self.condition:
3738
while not self.queue:
38-
self.condition.wait()
39+
self.condition.wait(timeout=1.0) # Add timeout to prevent infinite wait
40+
if not self.queue:
41+
raise TimeoutError("Queue empty timeout")
3942
item = self.queue.pop(0)
4043
self.condition.notify_all()
4144
return item
4245

4346
def size(self):
44-
return len(self.queue)
47+
with self.condition:
48+
return len(self.queue)
4549

4650
queue = ThreadSafeQueue(maxsize=5)
4751
results = []
@@ -84,29 +88,32 @@ def consumer(count):
8488
def test_reader_writer_pattern(self):
8589
"""Test thread safety in reader-writer pattern."""
8690

87-
@auto_thread_safe(['read', 'write'])
8891
class ThreadSafeDataStore:
8992
def __init__(self):
9093
self.data = {}
9194
self.read_count = 0
9295
self.write_count = 0
96+
self._lock = threading.RLock()
9397

9498
def read(self, key):
95-
self.read_count += 1
96-
time.sleep(0.001) # Simulate read time
97-
return self.data.get(key)
99+
with self._lock:
100+
self.read_count += 1
101+
time.sleep(0.001) # Simulate read time
102+
return self.data.get(key)
98103

99104
def write(self, key, value):
100-
self.write_count += 1
101-
time.sleep(0.001) # Simulate write time
102-
self.data[key] = value
105+
with self._lock:
106+
self.write_count += 1
107+
time.sleep(0.001) # Simulate write time
108+
self.data[key] = value
103109

104110
def get_stats(self):
105-
return {
106-
'reads': self.read_count,
107-
'writes': self.write_count,
108-
'data_size': len(self.data)
109-
}
111+
with self._lock:
112+
return {
113+
'reads': self.read_count,
114+
'writes': self.write_count,
115+
'data_size': len(self.data)
116+
}
110117

111118
store = ThreadSafeDataStore()
112119

@@ -194,14 +201,14 @@ def get_singleton():
194201
def test_resource_pool_pattern(self):
195202
"""Test thread-safe resource pool pattern."""
196203

197-
@auto_thread_safe(['get_resource', 'return_resource'])
198204
class ThreadSafeResourcePool:
199205
def __init__(self, create_resource_func, pool_size=5):
200206
self.create_resource = create_resource_func
201207
self.pool_size = pool_size
202208
self.available = []
203209
self.in_use = set()
204210
self.condition = threading.Condition()
211+
self._lock = threading.RLock()
205212

206213
# Pre-populate pool
207214
for _ in range(pool_size):
@@ -227,11 +234,12 @@ def return_resource(self, resource):
227234
self.condition.notify()
228235

229236
def stats(self):
230-
return {
231-
'available': len(self.available),
232-
'in_use': len(self.in_use),
233-
'total': len(self.available) + len(self.in_use)
234-
}
237+
with self._lock:
238+
return {
239+
'available': len(self.available),
240+
'in_use': len(self.in_use),
241+
'total': len(self.available) + len(self.in_use)
242+
}
235243

236244
# Create a simple resource (just a counter)
237245
resource_counter = [0]
@@ -266,40 +274,44 @@ def worker(worker_id):
266274
def test_cache_with_expiry_thread_safety(self):
267275
"""Test thread-safe cache with expiry."""
268276

269-
@auto_thread_safe(['get', 'put', 'cleanup'])
270277
class ThreadSafeExpiryCache:
271278
def __init__(self, default_ttl=1.0):
272279
self.cache = {}
273280
self.timestamps = {}
274281
self.default_ttl = default_ttl
282+
self._lock = threading.RLock()
275283

276284
def get(self, key):
277-
if key in self.cache:
278-
if time.time() - self.timestamps[key] < self.default_ttl:
279-
return self.cache[key]
280-
else:
281-
# Expired
282-
del self.cache[key]
283-
del self.timestamps[key]
284-
return None
285+
with self._lock:
286+
if key in self.cache:
287+
if time.time() - self.timestamps[key] < self.default_ttl:
288+
return self.cache[key]
289+
else:
290+
# Expired
291+
del self.cache[key]
292+
del self.timestamps[key]
293+
return None
285294

286295
def put(self, key, value, ttl=None):
287-
self.cache[key] = value
288-
self.timestamps[key] = time.time()
296+
with self._lock:
297+
self.cache[key] = value
298+
self.timestamps[key] = time.time()
289299

290300
def cleanup(self):
291-
current_time = time.time()
292-
expired_keys = [
293-
key for key, timestamp in self.timestamps.items()
294-
if current_time - timestamp >= self.default_ttl
295-
]
296-
for key in expired_keys:
297-
del self.cache[key]
298-
del self.timestamps[key]
299-
return len(expired_keys)
301+
with self._lock:
302+
current_time = time.time()
303+
expired_keys = [
304+
key for key, timestamp in self.timestamps.items()
305+
if current_time - timestamp >= self.default_ttl
306+
]
307+
for key in expired_keys:
308+
del self.cache[key]
309+
del self.timestamps[key]
310+
return len(expired_keys)
300311

301312
def size(self):
302-
return len(self.cache)
313+
with self._lock:
314+
return len(self.cache)
303315

304316
cache = ThreadSafeExpiryCache(default_ttl=0.1)
305317

@@ -335,39 +347,43 @@ def cache_worker(worker_id):
335347
def test_event_bus_thread_safety(self):
336348
"""Test thread-safe event bus pattern."""
337349

338-
@auto_thread_safe(['subscribe', 'unsubscribe', 'publish'])
339350
class ThreadSafeEventBus:
340351
def __init__(self):
341352
self.subscribers = {}
342353
self.event_count = {}
354+
self._lock = threading.RLock()
343355

344356
def subscribe(self, event_type, callback):
345-
if event_type not in self.subscribers:
346-
self.subscribers[event_type] = []
347-
self.subscribers[event_type].append(callback)
357+
with self._lock:
358+
if event_type not in self.subscribers:
359+
self.subscribers[event_type] = []
360+
self.subscribers[event_type].append(callback)
348361

349362
def unsubscribe(self, event_type, callback):
350-
if event_type in self.subscribers:
351-
try:
352-
self.subscribers[event_type].remove(callback)
353-
except ValueError:
354-
pass
363+
with self._lock:
364+
if event_type in self.subscribers:
365+
try:
366+
self.subscribers[event_type].remove(callback)
367+
except ValueError:
368+
pass
355369

356370
def publish(self, event_type, data):
357-
self.event_count[event_type] = self.event_count.get(event_type, 0) + 1
358-
if event_type in self.subscribers:
359-
for callback in self.subscribers[event_type][:]: # Copy to avoid modification during iteration
360-
try:
361-
callback(data)
362-
except Exception:
363-
pass # Ignore callback errors
371+
with self._lock:
372+
self.event_count[event_type] = self.event_count.get(event_type, 0) + 1
373+
if event_type in self.subscribers:
374+
for callback in self.subscribers[event_type][:]: # Copy to avoid modification during iteration
375+
try:
376+
callback(data)
377+
except Exception:
378+
pass # Ignore callback errors
364379

365380
def get_stats(self):
366-
return {
367-
'subscriber_count': sum(len(subs) for subs in self.subscribers.values()),
368-
'event_types': len(self.subscribers),
369-
'events_published': dict(self.event_count)
370-
}
381+
with self._lock:
382+
return {
383+
'subscriber_count': sum(len(subs) for subs in self.subscribers.values()),
384+
'event_types': len(self.subscribers),
385+
'events_published': dict(self.event_count)
386+
}
371387

372388
event_bus = ThreadSafeEventBus()
373389
received_events = []
@@ -425,38 +441,43 @@ def subscriber(event_type):
425441
def test_weak_reference_cleanup_thread_safety(self):
426442
"""Test thread safety with weak references and cleanup."""
427443

428-
@auto_thread_safe(['register', 'cleanup', 'get_count'])
429444
class ThreadSafeWeakRegistry:
430445
def __init__(self):
431446
self.registry = {}
432447
self.cleanup_count = 0
448+
self._lock = threading.RLock()
433449

434450
def register(self, obj, name):
435451
def cleanup_callback(weak_ref):
436-
self.cleanup_count += 1
437-
if name in self.registry:
438-
del self.registry[name]
452+
with self._lock:
453+
self.cleanup_count += 1
454+
if name in self.registry:
455+
del self.registry[name]
439456

440-
weak_ref = weakref.ref(obj, cleanup_callback)
441-
self.registry[name] = weak_ref
457+
with self._lock:
458+
weak_ref = weakref.ref(obj, cleanup_callback)
459+
self.registry[name] = weak_ref
442460

443461
def cleanup(self):
444-
# Manual cleanup of dead references
445-
dead_refs = []
446-
for name, weak_ref in self.registry.items():
447-
if weak_ref() is None:
448-
dead_refs.append(name)
449-
450-
for name in dead_refs:
451-
del self.registry[name]
452-
453-
return len(dead_refs)
462+
with self._lock:
463+
# Manual cleanup of dead references
464+
dead_refs = []
465+
for name, weak_ref in self.registry.items():
466+
if weak_ref() is None:
467+
dead_refs.append(name)
468+
469+
for name in dead_refs:
470+
del self.registry[name]
471+
472+
return len(dead_refs)
454473

455474
def get_count(self):
456-
return len(self.registry)
475+
with self._lock:
476+
return len(self.registry)
457477

458478
def get_cleanup_count(self):
459-
return self.cleanup_count
479+
with self._lock:
480+
return self.cleanup_count
460481

461482
registry = ThreadSafeWeakRegistry()
462483

0 commit comments

Comments
 (0)