Skip to content

Commit 8bb5653

Browse files
committed
WIP
1 parent 88b6357 commit 8bb5653

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ def __init__(
543543
self.max_connections = max_connections
544544
self.shutdown_timer = shutdown_timer
545545
self._received: dict[str, asyncio.Future] = {}
546+
self._received_subscriptions: dict[str, asyncio.Queue] = {}
546547
self._sending = asyncio.Queue()
547548
self._receiving_task = None # TODO rename, as this now does send/recv
548549
self._attempts = 0
@@ -673,7 +674,8 @@ async def _recv(self, recd: bytes) -> None:
673674
self._received[response["id"]].set_result(response)
674675
self._in_use_ids.remove(response["id"])
675676
elif "params" in response:
676-
self._received[response["params"]["subscription"]].set_result(response)
677+
sub_id = response["params"]["subscription"]
678+
await self._received_subscriptions[sub_id].put(response)
677679
else:
678680
raise KeyError(response)
679681

@@ -708,6 +710,9 @@ async def _start_sending(self, ws) -> Exception:
708710
self._received[i].cancel()
709711
return
710712

713+
async def add_subscription(self, subscription_id: str) -> None:
714+
self._received_subscriptions[subscription_id] = asyncio.Queue()
715+
711716
async def send(self, payload: dict) -> str:
712717
"""
713718
Sends a payload to the websocket connection.
@@ -729,7 +734,7 @@ async def send(self, payload: dict) -> str:
729734
await self._sending.put(to_send)
730735
return original_id
731736

732-
async def retrieve(self, item_id: int) -> Optional[dict]:
737+
async def retrieve(self, item_id: str) -> Optional[dict]:
733738
"""
734739
Retrieves a single item from received responses dict queue
735740
@@ -739,14 +744,20 @@ async def retrieve(self, item_id: int) -> Optional[dict]:
739744
Returns:
740745
retrieved item
741746
"""
742-
item: asyncio.Future = self._received.get(item_id)
743-
if item.done():
744-
self.max_subscriptions.release()
745-
del self._received[item_id]
746-
return item.result()
747+
item: Optional[asyncio.Future] = self._received.get(item_id)
748+
if item is not None:
749+
if item.done():
750+
self.max_subscriptions.release()
751+
del self._received[item_id]
752+
return item.result()
747753
else:
748-
await asyncio.sleep(0.1)
749-
return None
754+
try:
755+
return self._received_subscriptions[item_id].get_nowait()
756+
# TODO make sure to delete during unsubscribe
757+
except asyncio.QueueEmpty:
758+
pass
759+
await asyncio.sleep(0.1)
760+
return None
750761

751762

752763
class AsyncSubstrateInterface(SubstrateMixin):
@@ -2304,6 +2315,7 @@ async def _make_rpc_request(
23042315
item_id = request_manager.overwrite_request(
23052316
item_id, response["result"]
23062317
)
2318+
await ws.add_subscription(response["result"])
23072319
subscription_added = True
23082320
except KeyError:
23092321
raise SubstrateRequestException(str(response))
@@ -2347,12 +2359,13 @@ async def _make_rpc_request(
23472359
f"Retrying attempt {attempt + 1} of {self.max_retries}"
23482360
)
23492361
return await self._make_rpc_request(
2350-
payloads,
2351-
value_scale_type,
2352-
storage_item,
2353-
result_handler,
2354-
attempt + 1,
2355-
force_legacy_decode,
2362+
payloads=payloads,
2363+
value_scale_type=value_scale_type,
2364+
storage_item=storage_item,
2365+
result_handler=result_handler,
2366+
attempt=attempt + 1,
2367+
runtime=runtime,
2368+
force_legacy_decode=force_legacy_decode,
23562369
)
23572370

23582371
return request_manager.get_results()

async_substrate_interface/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,13 +377,13 @@ def __init__(self, payloads):
377377
self.responses = defaultdict(lambda: {"complete": False, "results": []})
378378
self.payloads_count = len(payloads)
379379

380-
def add_request(self, item_id: Union[int, str], request_id: Any):
380+
def add_request(self, item_id: str, request_id: Any):
381381
"""
382382
Adds an outgoing request to the responses map for later retrieval
383383
"""
384384
self.response_map[item_id] = request_id
385385

386-
def overwrite_request(self, item_id: int, request_id: Any):
386+
def overwrite_request(self, item_id: str, request_id: Any):
387387
"""
388388
Overwrites an existing request in the responses map with a new request_id. This is used
389389
for multipart responses that generate a subscription id we need to watch, rather than the initial
@@ -392,7 +392,7 @@ def overwrite_request(self, item_id: int, request_id: Any):
392392
self.response_map[request_id] = self.response_map.pop(item_id)
393393
return request_id
394394

395-
def add_response(self, item_id: int, response: dict, complete: bool):
395+
def add_response(self, item_id: str, response: dict, complete: bool):
396396
"""
397397
Maps a response to the request for later retrieval
398398
"""

0 commit comments

Comments
 (0)