@@ -65,6 +65,14 @@ class EventRegistrationError(Exception):
6565class NotIdentifiedError (Exception ):
6666 pass
6767
68+ async def _wait_cond (cond ):
69+ async with cond :
70+ await cond .wait ()
71+
72+ async def _wait_for_cond (cond , func ):
73+ async with cond :
74+ await cond .wait_for (func )
75+
6876class WebSocketClient :
6977 def __init__ (self ,
7078 url : str = "ws://localhost:4444" ,
@@ -92,7 +100,7 @@ async def connect(self):
92100 self .recv_task = None
93101 self .identified = False
94102 self .hello_message = None
95- self .ws = await websockets .connect (self .url , max_size = 2 ** 23 )
103+ self .ws = await websockets .connect (self .url , max_size = 2 ** 24 )
96104 self .recv_task = self .loop .create_task (self ._ws_recv_task ())
97105 return True
98106
@@ -101,12 +109,9 @@ async def wait_until_identified(self, timeout: int = 10):
101109 log .debug ('WebSocket session is not open. Returning early.' )
102110 return False
103111 try :
104- async with self .cond :
105- await asyncio .wait_for (self .cond .wait_for (self .is_identified ), timeout = timeout )
112+ await asyncio .wait_for (_wait_for_cond (self .cond , self .is_identified ), timeout = timeout )
106113 return True
107114 except asyncio .TimeoutError :
108- #if not self.ws.open:
109- # log.debug('WebSocket session is no longer open. Returning early.')
110115 return False
111116
112117
@@ -141,8 +146,7 @@ async def call(self, request: Request, timeout: int = 15):
141146 try :
142147 self .waiters [request_id ] = waiter
143148 await self .ws .send (json .dumps (request_payload ))
144- async with waiter .cond :
145- await asyncio .wait_for (waiter .cond .wait (), timeout = timeout )
149+ await asyncio .wait_for (_wait_cond (waiter .cond ), timeout = timeout )
146150 except asyncio .TimeoutError :
147151 raise MessageTimeout ('The request with type {} timed out after {} seconds.' .format (request .requestType , timeout ))
148152 finally :
@@ -191,8 +195,7 @@ async def call_batch(self, requests: list, timeout: int = 15, halt_on_failure: b
191195 try :
192196 self .waiters [request_batch_id ] = waiter
193197 await self .ws .send (json .dumps (request_batch_payload ))
194- async with waiter .cond :
195- await asyncio .wait_for (waiter .cond .wait (), timeout = timeout )
198+ await asyncio .wait_for (_wait_cond (waiter .cond ), timeout = timeout )
196199 except asyncio .TimeoutError :
197200 raise MessageTimeout ('The request batch timed out after {} seconds.' .format (timeout ))
198201 finally :
@@ -272,7 +275,7 @@ async def _ws_recv_task(self):
272275 while self .ws .open :
273276 message = ''
274277 try :
275- message = await asyncio . wait_for ( self .ws .recv (), timeout = 5 )
278+ message = await self .ws .recv ()
276279 if not message :
277280 continue
278281 incoming_payload = json .loads (message )
@@ -318,7 +321,4 @@ async def _ws_recv_task(self):
318321 break
319322 except json .JSONDecodeError :
320323 continue
321- except asyncio .TimeoutError :
322- async with self .cond :
323- self .cond .notify_all ()
324324 self .identified = False
0 commit comments