Skip to content

Commit 14faed1

Browse files
committed
WIP check-in
1 parent 54026f6 commit 14faed1

File tree

1 file changed

+42
-40
lines changed

1 file changed

+42
-40
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def __init__(
540540
self.max_subscriptions = asyncio.Semaphore(max_subscriptions)
541541
self.max_connections = max_connections
542542
self.shutdown_timer = shutdown_timer
543-
self._received = {}
543+
self._received: dict[str, asyncio.Future] = {}
544544
self._sending = asyncio.Queue()
545545
self._receiving_task = None # TODO rename, as this now does send/recv
546546
self._attempts = 0
@@ -601,22 +601,23 @@ async def connect(self, force=False):
601601
now = await self.loop_time()
602602
self.last_received = now
603603
self.last_sent = now
604-
if self._exit_task:
605-
self._exit_task.cancel()
606-
if self.state != State.CLOSING:
607-
if not self._initialized or force:
608-
try:
609-
await asyncio.wait_for(self._cancel(), timeout=10.0)
610-
except asyncio.TimeoutError:
611-
pass
612-
self.ws = await asyncio.wait_for(
613-
connect(self.ws_url, **self._options), timeout=10.0
614-
)
615-
if self._receiving_task is None or self._receiving_task.done():
616-
self._receiving_task = asyncio.get_running_loop().create_task(
617-
self._handler(self.ws)
604+
async with self._lock:
605+
if self._exit_task:
606+
self._exit_task.cancel()
607+
if self.state not in (State.OPEN, State.CONNECTING):
608+
if not self._initialized or force:
609+
try:
610+
await asyncio.wait_for(self._cancel(), timeout=10.0)
611+
except asyncio.TimeoutError:
612+
pass
613+
self.ws = await asyncio.wait_for(
614+
connect(self.ws_url, **self._options), timeout=10.0
618615
)
619-
self._initialized = True
616+
if self._receiving_task is None or self._receiving_task.done():
617+
self._receiving_task = asyncio.get_running_loop().create_task(
618+
self._handler(self.ws)
619+
)
620+
self._initialized = True
620621

621622
async def _handler(self, ws: ClientConnection):
622623
consumer_task = asyncio.create_task(self._start_receiving(ws))
@@ -669,10 +670,10 @@ async def _recv(self, recd) -> None:
669670
response = json.loads(recd)
670671
self.last_received = await self.loop_time()
671672
if "id" in response:
672-
self._received[response["id"]] = response
673+
self._received[response["id"]].set_result(response)
673674
self._in_use_ids.remove(response["id"])
674675
elif "params" in response:
675-
self._received[response["params"]["subscription"]] = response
676+
self._received[response["params"]["subscription"]].set_result(response)
676677
else:
677678
raise KeyError(response)
678679
except ssl.SSLError:
@@ -685,19 +686,26 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
685686
async for recd in ws:
686687
await self._recv(recd)
687688
except Exception as e:
688-
return e
689+
for i in self._received.keys():
690+
self._received[i].set_exception(e)
691+
return
689692

690693
async def _start_sending(self, ws) -> Exception:
694+
to_send = None
691695
try:
692696
while True:
693-
logger.info("699 Not Empty")
694697
to_send = await self._sending.get()
695698
if self._log_raw_websockets:
696-
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}")
699+
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}}")
697700
await ws.send(json.dumps(to_send))
698701
self.last_sent = await self.loop_time()
699702
except Exception as e:
700-
return e
703+
if to_send is not None:
704+
self._received[to_send["id"]].set_exception(e)
705+
else:
706+
for i in self._received.keys():
707+
self._received[i].set_exception(e)
708+
return
701709

702710
async def send(self, payload: dict) -> str:
703711
"""
@@ -715,8 +723,8 @@ async def send(self, payload: dict) -> str:
715723
while original_id in self._in_use_ids:
716724
original_id = get_next_id()
717725
self._in_use_ids.add(original_id)
726+
self._received[original_id] = asyncio.get_running_loop().create_future()
718727
to_send = {**payload, **{"id": original_id}}
719-
logger.info(f"Sending {to_send}")
720728
await self._sending.put(to_send)
721729
return original_id
722730

@@ -730,11 +738,12 @@ async def retrieve(self, item_id: int) -> Optional[dict]:
730738
Returns:
731739
retrieved item
732740
"""
733-
try:
734-
item = self._received.pop(item_id)
741+
item: asyncio.Future = self._received.get(item_id)
742+
if item.done():
735743
self.max_subscriptions.release()
736-
return item
737-
except KeyError:
744+
del self._received[item_id]
745+
return item.result()
746+
else:
738747
await asyncio.sleep(0.1)
739748
return None
740749

@@ -2263,16 +2272,9 @@ async def _make_rpc_request(
22632272
subscription_added = False
22642273

22652274
async with self.ws as ws:
2266-
if len(payloads) > 1:
2267-
send_coroutines = await asyncio.gather(
2268-
*[ws.send(item["payload"]) for item in payloads]
2269-
)
2270-
for item_id, item in zip(send_coroutines, payloads):
2271-
request_manager.add_request(item_id, item["id"])
2272-
else:
2273-
item = payloads[0]
2274-
item_id = await ws.send(item["payload"])
2275-
request_manager.add_request(item_id, item["id"])
2275+
for payload in payloads:
2276+
item_id = await ws.send(payload)
2277+
request_manager.add_request(item_id, payload["id"])
22762278

22772279
while True:
22782280
for item_id in list(request_manager.response_map.keys()):
@@ -2311,17 +2313,17 @@ async def _make_rpc_request(
23112313
if request_manager.is_complete:
23122314
break
23132315
if (
2314-
(current_time := await self.ws.loop_time()) - self.ws.last_received
2316+
(current_time := await ws.loop_time()) - ws.last_received
23152317
>= self.retry_timeout
2316-
and current_time - self.ws.last_sent >= self.retry_timeout
2318+
and current_time - ws.last_sent >= self.retry_timeout
23172319
):
23182320
if attempt >= self.max_retries:
23192321
logger.error(
23202322
f"Timed out waiting for RPC requests {attempt} times. Exiting."
23212323
)
23222324
raise MaxRetriesExceeded("Max retries reached.")
23232325
else:
2324-
self.ws.last_received = time.time()
2326+
self.ws.last_received = await ws.loop_time()
23252327
await self.ws.connect(force=True)
23262328
logger.warning(
23272329
f"Timed out waiting for RPC requests. "

0 commit comments

Comments
 (0)