1
+ from collections .abc import AsyncGenerator
2
+ from contextlib import asynccontextmanager
1
3
from types import (
2
4
TracebackType ,
3
5
)
@@ -36,13 +38,21 @@ class ReadWriteLock:
36
38
def __init__ (self ) -> None :
37
39
self ._readers = 0
38
40
self ._readers_lock = trio .Lock () # Protects readers count
39
- self ._writer_lock = trio .Semaphore (1 ) # Acts like a task-transferable lock
41
+ self ._writer_lock = trio .Semaphore (1 ) # Ensures mutual exclusion for writers
40
42
41
43
async def acquire_read (self ) -> None :
42
- async with self ._readers_lock :
43
- self ._readers += 1
44
- if self ._readers == 1 :
45
- await self ._writer_lock .acquire ()
44
+ try :
45
+ async with self ._readers_lock :
46
+ self ._readers += 1
47
+ if self ._readers == 1 :
48
+ await self ._writer_lock .acquire ()
49
+ except trio .Cancelled :
50
+ async with self ._readers_lock :
51
+ if self ._readers > 0 :
52
+ self ._readers -= 1
53
+ if self ._readers == 0 :
54
+ self ._writer_lock .release ()
55
+ raise
46
56
47
57
async def release_read (self ) -> None :
48
58
async with self ._readers_lock :
@@ -51,11 +61,30 @@ async def release_read(self) -> None:
51
61
self ._writer_lock .release ()
52
62
53
63
async def acquire_write (self ) -> None :
54
- await self ._writer_lock .acquire ()
64
+ try :
65
+ await self ._writer_lock .acquire ()
66
+ except trio .Cancelled :
67
+ raise
55
68
56
69
def release_write (self ) -> None :
57
70
self ._writer_lock .release ()
58
71
72
+ @asynccontextmanager
73
+ async def read_lock (self ) -> AsyncGenerator [None , None ]:
74
+ await self .acquire_read ()
75
+ try :
76
+ yield
77
+ finally :
78
+ await self .release_read ()
79
+
80
+ @asynccontextmanager
81
+ async def write_lock (self ) -> AsyncGenerator [None , None ]:
82
+ await self .acquire_write ()
83
+ try :
84
+ yield
85
+ finally :
86
+ self .release_write ()
87
+
59
88
60
89
class MplexStream (IMuxedStream ):
61
90
"""
0 commit comments