3535
3636
3737class ReadWriteLock :
38+ """
39+ A read-write lock that allows multiple concurrent readers
40+ or one exclusive writer, implemented using Trio primitives.
41+ """
42+
3843 def __init__ (self ) -> None :
3944 self ._readers = 0
40- self ._readers_lock = trio .Lock () # Protects readers count
41- self ._writer_lock = trio .Semaphore (1 ) # Ensures mutual exclusion for writers
45+ self ._readers_lock = trio .Lock () # Protects access to _readers count
46+ self ._writer_lock = trio .Semaphore (1 ) # Allows only one writer at a time
4247
4348 async def acquire_read (self ) -> None :
49+ """Acquire a read lock. Multiple readers can hold it simultaneously."""
4450 try :
4551 async with self ._readers_lock :
46- self ._readers += 1
47- if self ._readers == 1 :
52+ if self ._readers == 0 :
4853 await self ._writer_lock .acquire ()
54+ self ._readers += 1
4955 except trio .Cancelled :
50- async with self ._readers_lock :
51- if self ._readers > 0 :
52- self ._readers -= 1
53- if self ._readers == 0 :
54- self ._writer_lock .release ()
5556 raise
5657
5758 async def release_read (self ) -> None :
59+ """Release a read lock."""
5860 async with self ._readers_lock :
59- self ._readers -= 1
60- if self ._readers == 0 :
61+ if self ._readers == 1 :
6162 self ._writer_lock .release ()
63+ self ._readers -= 1
6264
6365 async def acquire_write (self ) -> None :
66+ """Acquire an exclusive write lock."""
6467 try :
6568 await self ._writer_lock .acquire ()
6669 except trio .Cancelled :
6770 raise
6871
6972 def release_write (self ) -> None :
73+ """Release the exclusive write lock."""
7074 self ._writer_lock .release ()
7175
7276 @asynccontextmanager
7377 async def read_lock (self ) -> AsyncGenerator [None , None ]:
74- await self .acquire_read ()
78+ """Context manager for acquiring and releasing a read lock safely."""
79+ acquire = False
7580 try :
81+ await self .acquire_read ()
82+ acquire = True
7683 yield
7784 finally :
78- await self .release_read ()
85+ if acquire :
86+ with trio .CancelScope () as scope :
87+ scope .shield = True
88+ await self .release_read ()
7989
8090 @asynccontextmanager
8191 async def write_lock (self ) -> AsyncGenerator [None , None ]:
82- await self .acquire_write ()
92+ """Context manager for acquiring and releasing a write lock safely."""
93+ acquire = False
8394 try :
95+ await self .acquire_write ()
96+ acquire = True
8497 yield
8598 finally :
86- self .release_write ()
99+ if acquire :
100+ self .release_write ()
87101
88102
89103class MplexStream (IMuxedStream ):
@@ -168,9 +182,7 @@ async def read(self, n: int | None = None) -> bytes:
168182 :param n: number of bytes to read
169183 :return: bytes actually read
170184 """
171- await self .rw_lock .acquire_read ()
172- payload : bytes = b""
173- try :
185+ async with self .rw_lock .read_lock ():
174186 if n is not None and n < 0 :
175187 raise ValueError (
176188 "the number of bytes to read n must be non-negative or "
@@ -210,21 +222,17 @@ async def read(self, n: int | None = None) -> bytes:
210222 "This should never happen."
211223 ) from error
212224 self ._buf .extend (self ._read_return_when_blocked ())
213- chunk = self ._buf [:n ]
214- self ._buf = self ._buf [len (chunk ) :]
215- payload = bytes (chunk )
216- finally :
217- await self .rw_lock .release_read ()
218- return payload
225+ payload = self ._buf [:n ]
226+ self ._buf = self ._buf [len (payload ) :]
227+ return bytes (payload )
219228
220229 async def write (self , data : bytes ) -> None :
221230 """
222231 Write to stream.
223232
224233 :return: number of bytes written
225234 """
226- await self .rw_lock .acquire_write ()
227- try :
235+ async with self .rw_lock .write_lock ():
228236 if self .event_local_closed .is_set ():
229237 raise MplexStreamClosed (f"cannot write to closed stream: data={ data !r} " )
230238 flag = (
@@ -233,8 +241,6 @@ async def write(self, data: bytes) -> None:
233241 else HeaderTags .MessageReceiver
234242 )
235243 await self .muxed_conn .send_message (flag , data , self .stream_id )
236- finally :
237- self .rw_lock .release_write ()
238244
239245 async def close (self ) -> None :
240246 """
0 commit comments