@@ -18,7 +18,6 @@ class TestAdvancedThreadSafetyPatterns:
1818 def test_producer_consumer_pattern (self ):
1919 """Test thread safety in producer-consumer pattern."""
2020
21- @auto_thread_safe (['put' , 'get' ])
2221 class ThreadSafeQueue :
2322 def __init__ (self , maxsize = 10 ):
2423 self .queue = []
@@ -28,20 +27,25 @@ def __init__(self, maxsize=10):
2827 def put (self , item ):
2928 with self .condition :
3029 while len (self .queue ) >= self .maxsize :
31- self .condition .wait ()
30+ self .condition .wait (timeout = 1.0 ) # Add timeout to prevent infinite wait
31+ if len (self .queue ) >= self .maxsize :
32+ raise TimeoutError ("Queue full timeout" )
3233 self .queue .append (item )
3334 self .condition .notify_all ()
3435
3536 def get (self ):
3637 with self .condition :
3738 while not self .queue :
38- self .condition .wait ()
39+ self .condition .wait (timeout = 1.0 ) # Add timeout to prevent infinite wait
40+ if not self .queue :
41+ raise TimeoutError ("Queue empty timeout" )
3942 item = self .queue .pop (0 )
4043 self .condition .notify_all ()
4144 return item
4245
4346 def size (self ):
44- return len (self .queue )
47+ with self .condition :
48+ return len (self .queue )
4549
4650 queue = ThreadSafeQueue (maxsize = 5 )
4751 results = []
@@ -84,29 +88,32 @@ def consumer(count):
8488 def test_reader_writer_pattern (self ):
8589 """Test thread safety in reader-writer pattern."""
8690
87- @auto_thread_safe (['read' , 'write' ])
8891 class ThreadSafeDataStore :
8992 def __init__ (self ):
9093 self .data = {}
9194 self .read_count = 0
9295 self .write_count = 0
96+ self ._lock = threading .RLock ()
9397
9498 def read (self , key ):
95- self .read_count += 1
96- time .sleep (0.001 ) # Simulate read time
97- return self .data .get (key )
99+ with self ._lock :
100+ self .read_count += 1
101+ time .sleep (0.001 ) # Simulate read time
102+ return self .data .get (key )
98103
99104 def write (self , key , value ):
100- self .write_count += 1
101- time .sleep (0.001 ) # Simulate write time
102- self .data [key ] = value
105+ with self ._lock :
106+ self .write_count += 1
107+ time .sleep (0.001 ) # Simulate write time
108+ self .data [key ] = value
103109
104110 def get_stats (self ):
105- return {
106- 'reads' : self .read_count ,
107- 'writes' : self .write_count ,
108- 'data_size' : len (self .data )
109- }
111+ with self ._lock :
112+ return {
113+ 'reads' : self .read_count ,
114+ 'writes' : self .write_count ,
115+ 'data_size' : len (self .data )
116+ }
110117
111118 store = ThreadSafeDataStore ()
112119
@@ -194,14 +201,14 @@ def get_singleton():
194201 def test_resource_pool_pattern (self ):
195202 """Test thread-safe resource pool pattern."""
196203
197- @auto_thread_safe (['get_resource' , 'return_resource' ])
198204 class ThreadSafeResourcePool :
199205 def __init__ (self , create_resource_func , pool_size = 5 ):
200206 self .create_resource = create_resource_func
201207 self .pool_size = pool_size
202208 self .available = []
203209 self .in_use = set ()
204210 self .condition = threading .Condition ()
211+ self ._lock = threading .RLock ()
205212
206213 # Pre-populate pool
207214 for _ in range (pool_size ):
@@ -227,11 +234,12 @@ def return_resource(self, resource):
227234 self .condition .notify ()
228235
229236 def stats (self ):
230- return {
231- 'available' : len (self .available ),
232- 'in_use' : len (self .in_use ),
233- 'total' : len (self .available ) + len (self .in_use )
234- }
237+ with self ._lock :
238+ return {
239+ 'available' : len (self .available ),
240+ 'in_use' : len (self .in_use ),
241+ 'total' : len (self .available ) + len (self .in_use )
242+ }
235243
236244 # Create a simple resource (just a counter)
237245 resource_counter = [0 ]
@@ -266,40 +274,44 @@ def worker(worker_id):
266274 def test_cache_with_expiry_thread_safety (self ):
267275 """Test thread-safe cache with expiry."""
268276
269- @auto_thread_safe (['get' , 'put' , 'cleanup' ])
270277 class ThreadSafeExpiryCache :
271278 def __init__ (self , default_ttl = 1.0 ):
272279 self .cache = {}
273280 self .timestamps = {}
274281 self .default_ttl = default_ttl
282+ self ._lock = threading .RLock ()
275283
276284 def get (self , key ):
277- if key in self .cache :
278- if time .time () - self .timestamps [key ] < self .default_ttl :
279- return self .cache [key ]
280- else :
281- # Expired
282- del self .cache [key ]
283- del self .timestamps [key ]
284- return None
285+ with self ._lock :
286+ if key in self .cache :
287+ if time .time () - self .timestamps [key ] < self .default_ttl :
288+ return self .cache [key ]
289+ else :
290+ # Expired
291+ del self .cache [key ]
292+ del self .timestamps [key ]
293+ return None
285294
286295 def put (self , key , value , ttl = None ):
287- self .cache [key ] = value
288- self .timestamps [key ] = time .time ()
296+ with self ._lock :
297+ self .cache [key ] = value
298+ self .timestamps [key ] = time .time ()
289299
290300 def cleanup (self ):
291- current_time = time .time ()
292- expired_keys = [
293- key for key , timestamp in self .timestamps .items ()
294- if current_time - timestamp >= self .default_ttl
295- ]
296- for key in expired_keys :
297- del self .cache [key ]
298- del self .timestamps [key ]
299- return len (expired_keys )
301+ with self ._lock :
302+ current_time = time .time ()
303+ expired_keys = [
304+ key for key , timestamp in self .timestamps .items ()
305+ if current_time - timestamp >= self .default_ttl
306+ ]
307+ for key in expired_keys :
308+ del self .cache [key ]
309+ del self .timestamps [key ]
310+ return len (expired_keys )
300311
301312 def size (self ):
302- return len (self .cache )
313+ with self ._lock :
314+ return len (self .cache )
303315
304316 cache = ThreadSafeExpiryCache (default_ttl = 0.1 )
305317
@@ -335,39 +347,43 @@ def cache_worker(worker_id):
335347 def test_event_bus_thread_safety (self ):
336348 """Test thread-safe event bus pattern."""
337349
338- @auto_thread_safe (['subscribe' , 'unsubscribe' , 'publish' ])
339350 class ThreadSafeEventBus :
340351 def __init__ (self ):
341352 self .subscribers = {}
342353 self .event_count = {}
354+ self ._lock = threading .RLock ()
343355
344356 def subscribe (self , event_type , callback ):
345- if event_type not in self .subscribers :
346- self .subscribers [event_type ] = []
347- self .subscribers [event_type ].append (callback )
357+ with self ._lock :
358+ if event_type not in self .subscribers :
359+ self .subscribers [event_type ] = []
360+ self .subscribers [event_type ].append (callback )
348361
349362 def unsubscribe (self , event_type , callback ):
350- if event_type in self .subscribers :
351- try :
352- self .subscribers [event_type ].remove (callback )
353- except ValueError :
354- pass
363+ with self ._lock :
364+ if event_type in self .subscribers :
365+ try :
366+ self .subscribers [event_type ].remove (callback )
367+ except ValueError :
368+ pass
355369
356370 def publish (self , event_type , data ):
357- self .event_count [event_type ] = self .event_count .get (event_type , 0 ) + 1
358- if event_type in self .subscribers :
359- for callback in self .subscribers [event_type ][:]: # Copy to avoid modification during iteration
360- try :
361- callback (data )
362- except Exception :
363- pass # Ignore callback errors
371+ with self ._lock :
372+ self .event_count [event_type ] = self .event_count .get (event_type , 0 ) + 1
373+ if event_type in self .subscribers :
374+ for callback in self .subscribers [event_type ][:]: # Copy to avoid modification during iteration
375+ try :
376+ callback (data )
377+ except Exception :
378+ pass # Ignore callback errors
364379
365380 def get_stats (self ):
366- return {
367- 'subscriber_count' : sum (len (subs ) for subs in self .subscribers .values ()),
368- 'event_types' : len (self .subscribers ),
369- 'events_published' : dict (self .event_count )
370- }
381+ with self ._lock :
382+ return {
383+ 'subscriber_count' : sum (len (subs ) for subs in self .subscribers .values ()),
384+ 'event_types' : len (self .subscribers ),
385+ 'events_published' : dict (self .event_count )
386+ }
371387
372388 event_bus = ThreadSafeEventBus ()
373389 received_events = []
@@ -425,38 +441,43 @@ def subscriber(event_type):
425441 def test_weak_reference_cleanup_thread_safety (self ):
426442 """Test thread safety with weak references and cleanup."""
427443
428- @auto_thread_safe (['register' , 'cleanup' , 'get_count' ])
429444 class ThreadSafeWeakRegistry :
430445 def __init__ (self ):
431446 self .registry = {}
432447 self .cleanup_count = 0
448+ self ._lock = threading .RLock ()
433449
434450 def register (self , obj , name ):
435451 def cleanup_callback (weak_ref ):
436- self .cleanup_count += 1
437- if name in self .registry :
438- del self .registry [name ]
452+ with self ._lock :
453+ self .cleanup_count += 1
454+ if name in self .registry :
455+ del self .registry [name ]
439456
440- weak_ref = weakref .ref (obj , cleanup_callback )
441- self .registry [name ] = weak_ref
457+ with self ._lock :
458+ weak_ref = weakref .ref (obj , cleanup_callback )
459+ self .registry [name ] = weak_ref
442460
443461 def cleanup (self ):
444- # Manual cleanup of dead references
445- dead_refs = []
446- for name , weak_ref in self .registry .items ():
447- if weak_ref () is None :
448- dead_refs .append (name )
449-
450- for name in dead_refs :
451- del self .registry [name ]
452-
453- return len (dead_refs )
462+ with self ._lock :
463+ # Manual cleanup of dead references
464+ dead_refs = []
465+ for name , weak_ref in self .registry .items ():
466+ if weak_ref () is None :
467+ dead_refs .append (name )
468+
469+ for name in dead_refs :
470+ del self .registry [name ]
471+
472+ return len (dead_refs )
454473
455474 def get_count (self ):
456- return len (self .registry )
475+ with self ._lock :
476+ return len (self .registry )
457477
458478 def get_cleanup_count (self ):
459- return self .cleanup_count
479+ with self ._lock :
480+ return self .cleanup_count
460481
461482 registry = ThreadSafeWeakRegistry ()
462483
0 commit comments