Skip to content

Commit 875c7f1

Browse files
authored
Add wait_for_block method (#2489)
Fixes the `AsyncSubstrateInterface._get_block_handler` method's `result_handler` and adds a `wait_for_block` method.
1 parent 8128c80 commit 875c7f1

File tree

2 files changed

+112
-21
lines changed

2 files changed

+112
-21
lines changed

bittensor/utils/async_substrate_interface.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import asyncio
8+
import inspect
89
import json
910
import random
1011
from collections import defaultdict
@@ -1171,14 +1172,14 @@ async def _get_block_handler(
11711172
include_author: bool = False,
11721173
header_only: bool = False,
11731174
finalized_only: bool = False,
1174-
subscription_handler: Optional[Callable] = None,
1175+
subscription_handler: Optional[Callable[[dict], Awaitable[Any]]] = None,
11751176
):
11761177
try:
11771178
await self.init_runtime(block_hash=block_hash)
11781179
except BlockNotFound:
11791180
return None
11801181

1181-
async def decode_block(block_data, block_data_hash=None):
1182+
async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]:
11821183
if block_data:
11831184
if block_data_hash:
11841185
block_data["header"]["hash"] = block_data_hash
@@ -1193,12 +1194,12 @@ async def decode_block(block_data, block_data_hash=None):
11931194

11941195
if "extrinsics" in block_data:
11951196
for idx, extrinsic_data in enumerate(block_data["extrinsics"]):
1196-
extrinsic_decoder = extrinsic_cls(
1197-
data=ScaleBytes(extrinsic_data),
1198-
metadata=self.__metadata,
1199-
runtime_config=self.runtime_config,
1200-
)
12011197
try:
1198+
extrinsic_decoder = extrinsic_cls(
1199+
data=ScaleBytes(extrinsic_data),
1200+
metadata=self.__metadata,
1201+
runtime_config=self.runtime_config,
1202+
)
12021203
extrinsic_decoder.decode(check_remaining=True)
12031204
block_data["extrinsics"][idx] = extrinsic_decoder
12041205

@@ -1314,23 +1315,29 @@ async def decode_block(block_data, block_data_hash=None):
13141315
if callable(subscription_handler):
13151316
rpc_method_prefix = "Finalized" if finalized_only else "New"
13161317

1317-
async def result_handler(message, update_nr, subscription_id):
1318-
new_block = await decode_block({"header": message["params"]["result"]})
1318+
async def result_handler(
1319+
message: dict, subscription_id: str
1320+
) -> tuple[Any, bool]:
1321+
reached = False
1322+
subscription_result = None
1323+
if "params" in message:
1324+
new_block = await decode_block(
1325+
{"header": message["params"]["result"]}
1326+
)
13191327

1320-
subscription_result = subscription_handler(
1321-
new_block, update_nr, subscription_id
1322-
)
1328+
subscription_result = await subscription_handler(new_block)
13231329

1324-
if subscription_result is not None:
1325-
# Handler returned end result: unsubscribe from further updates
1326-
self._forgettable_task = asyncio.create_task(
1327-
self.rpc_request(
1328-
f"chain_unsubscribe{rpc_method_prefix}Heads",
1329-
[subscription_id],
1330+
if subscription_result is not None:
1331+
reached = True
1332+
# Handler returned end result: unsubscribe from further updates
1333+
self._forgettable_task = asyncio.create_task(
1334+
self.rpc_request(
1335+
f"chain_unsubscribe{rpc_method_prefix}Heads",
1336+
[subscription_id],
1337+
)
13301338
)
1331-
)
13321339

1333-
return subscription_result
1340+
return subscription_result, reached
13341341

13351342
result = await self._make_rpc_request(
13361343
[
@@ -1343,7 +1350,7 @@ async def result_handler(message, update_nr, subscription_id):
13431350
result_handler=result_handler,
13441351
)
13451352

1346-
return result
1353+
return result["_get_block_handler"][-1]
13471354

13481355
else:
13491356
if header_only:
@@ -2770,3 +2777,49 @@ async def close(self):
27702777
await self.ws.shutdown()
27712778
except AttributeError:
27722779
pass
2780+
2781+
async def wait_for_block(
2782+
self,
2783+
block: int,
2784+
result_handler: Callable[[dict], Awaitable[Any]],
2785+
task_return: bool = True,
2786+
) -> Union[asyncio.Task, Union[bool, Any]]:
2787+
"""
2788+
Executes the result_handler when the chain has reached the block specified.
2789+
2790+
Args:
2791+
block: block number
2792+
result_handler: coroutine executed upon reaching the block number. This can be basically anything, but
2793+
must accept one single arg, a dict with the block data; whether you use this data or not is entirely
2794+
up to you.
2795+
task_return: True to immediately return the result of wait_for_block as an asyncio Task, False to wait
2796+
for the block to be reached, and return the result of the result handler.
2797+
2798+
Returns:
2799+
Either an asyncio.Task (which contains the running subscription, and whose `result()` will contain the
2800+
return of the result_handler), or the result itself, depending on `task_return` flag.
2801+
Note that if your result_handler returns `None`, this method will return `True`, otherwise
2802+
the return will be the result of your result_handler.
2803+
"""
2804+
2805+
async def _handler(block_data: dict[str, Any]):
2806+
required_number = block
2807+
number = block_data["header"]["number"]
2808+
if number >= required_number:
2809+
return (
2810+
r if (r := await result_handler(block_data)) is not None else True
2811+
)
2812+
2813+
args = inspect.getfullargspec(result_handler).args
2814+
if len(args) != 1:
2815+
raise ValueError(
2816+
"result_handler must take exactly one arg: the dict block data."
2817+
)
2818+
2819+
co = self._get_block_handler(
2820+
self.last_block_hash, subscription_handler=_handler
2821+
)
2822+
if task_return is True:
2823+
return asyncio.create_task(co)
2824+
else:
2825+
return await co
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
import asyncio
3+
from bittensor.utils import async_substrate_interface
4+
from typing import Any
5+
6+
7+
@pytest.mark.asyncio
8+
async def test_wait_for_block_invalid_result_handler():
9+
chain_interface = async_substrate_interface.AsyncSubstrateInterface(
10+
"dummy_endpoint"
11+
)
12+
13+
with pytest.raises(ValueError):
14+
15+
async def dummy_handler(
16+
block_data: dict[str, Any], extra_arg
17+
): # extra argument
18+
return block_data.get("header", {}).get("number", -1) == 2
19+
20+
await chain_interface.wait_for_block(
21+
block=2, result_handler=dummy_handler, task_return=False
22+
)
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_wait_for_block_async_return():
27+
chain_interface = async_substrate_interface.AsyncSubstrateInterface(
28+
"dummy_endpoint"
29+
)
30+
31+
async def dummy_handler(block_data: dict[str, Any]) -> bool:
32+
return block_data.get("header", {}).get("number", -1) == 2
33+
34+
result = await chain_interface.wait_for_block(
35+
block=2, result_handler=dummy_handler, task_return=True
36+
)
37+
38+
assert isinstance(result, asyncio.Task)

0 commit comments

Comments
 (0)