diff --git a/CHANGELOG.md b/CHANGELOG.md index 6469d4d..f1cc35d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## 1.1.0 /2025-04-07 + +## What's Changed +* Fix: response is still missing for callback by @zyzniewski-reef in https://github.com/opentensor/async-substrate-interface/pull/90 +* Expose websockets exceptions by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/91 +* Improved Query Map Decodes by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/84 + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.0.9...v1.1.0 + ## 1.0.9 /2025-03-26 ## What's Changed diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 502b743..e8a95aa 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -56,6 +56,9 @@ ) from async_substrate_interface.utils.storage import StorageKey from async_substrate_interface.type_registry import _TYPE_REGISTRY +from async_substrate_interface.utils.decoding import ( + decode_query_map, +) if TYPE_CHECKING: from websockets.asyncio.client import ClientConnection @@ -898,7 +901,7 @@ async def decode_scale( else: return obj - async def load_runtime(self, runtime): + def load_runtime(self, runtime): self.runtime = runtime # Update type registry @@ -954,7 +957,7 @@ async def init_runtime( ) if self.runtime and runtime_version == self.runtime.runtime_version: - return + return self.runtime runtime = self.runtime_cache.retrieve(runtime_version=runtime_version) if not runtime: @@ -990,7 +993,7 @@ async def init_runtime( runtime_version=runtime_version, runtime=runtime ) - await self.load_runtime(runtime) + self.load_runtime(runtime) if self.ss58_format is None: # Check and apply runtime constants @@ -1000,6 +1003,7 @@ async def init_runtime( if ss58_prefix_constant: self.ss58_format = ss58_prefix_constant + return runtime async def create_storage_key( self, @@ -2892,12 +2896,11 @@ async def query_map( Returns: AsyncQueryMapResult object """ - hex_to_bytes_ = hex_to_bytes params = params or [] block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash) if block_hash: self.last_block_hash = block_hash - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) metadata_pallet = self.runtime.metadata.get_metadata_pallet(module) if not metadata_pallet: @@ -2952,19 +2955,6 @@ async def query_map( result = [] last_key = None - def concat_hash_len(key_hasher: str) -> int: - """ - Helper function to avoid if statements - """ - if key_hasher == "Blake2_128Concat": - return 16 - elif key_hasher == "Twox64Concat": - return 8 - elif key_hasher == "Identity": - return 0 - else: - raise ValueError("Unsupported hash type") - if len(result_keys) > 0: last_key = result_keys[-1] @@ -2975,51 +2965,17 @@ def concat_hash_len(key_hasher: str) -> int: if "error" in response: raise SubstrateRequestException(response["error"]["message"]) - for result_group in response["result"]: - for item in result_group["changes"]: - try: - # Determine type string - key_type_string = [] - for n in range(len(params), len(param_types)): - key_type_string.append( - f"[u8; {concat_hash_len(key_hashers[n])}]" - ) - key_type_string.append(param_types[n]) - - item_key_obj = await self.decode_scale( - type_string=f"({', '.join(key_type_string)})", - scale_bytes=bytes.fromhex(item[0][len(prefix) :]), - return_scale_obj=True, - ) - - # strip key_hashers to use as item key - if len(param_types) - len(params) == 1: - item_key = item_key_obj[1] - else: - item_key = tuple( - item_key_obj[key + 1] - for key in range(len(params), len(param_types) + 1, 2) - ) - - except Exception as _: - if not ignore_decoding_errors: - raise - item_key = None - - try: - item_bytes = hex_to_bytes_(item[1]) - - item_value = await self.decode_scale( - type_string=value_type, - scale_bytes=item_bytes, - return_scale_obj=True, - ) - except Exception as _: - if not ignore_decoding_errors: - raise - item_value = None - result.append([item_key, item_value]) + result = decode_query_map( + result_group["changes"], + prefix, + runtime, + param_types, + params, + value_type, + key_hashers, + ignore_decoding_errors, + ) return AsyncQueryMapResult( records=result, page_size=page_size, diff --git a/async_substrate_interface/errors.py b/async_substrate_interface/errors.py index 7f619ad..9de753b 100644 --- a/async_substrate_interface/errors.py +++ b/async_substrate_interface/errors.py @@ -1,3 +1,9 @@ +from websockets.exceptions import ConnectionClosed, InvalidHandshake + +ConnectionClosed = ConnectionClosed +InvalidHandshake = InvalidHandshake + + class SubstrateRequestException(Exception): pass diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index c2c9b3c..f463f2f 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -34,6 +34,7 @@ from async_substrate_interface.utils.decoding import ( _determine_if_old_runtime_call, _bt_decode_to_dict_or_list, + decode_query_map, ) from async_substrate_interface.utils.storage import StorageKey from async_substrate_interface.type_registry import _TYPE_REGISTRY @@ -525,7 +526,9 @@ def __enter__(self): return self def __del__(self): - self.close() + self.ws.close() + print("DELETING SUBSTATE") + # self.ws.protocol.fail(code=1006) # ABNORMAL_CLOSURE def initialize(self): """ @@ -703,7 +706,7 @@ def init_runtime( ) if self.runtime and runtime_version == self.runtime.runtime_version: - return + return self.runtime runtime = self.runtime_cache.retrieve(runtime_version=runtime_version) if not runtime: @@ -757,6 +760,7 @@ def init_runtime( if ss58_prefix_constant: self.ss58_format = ss58_prefix_constant + return runtime def create_storage_key( self, @@ -1626,7 +1630,7 @@ def _make_rpc_request( if item_id not in request_manager.responses or isinstance( result_handler, Callable ): - if response := _received.pop(item_id): + if response := _received.pop(item_id, None): if ( isinstance(result_handler, Callable) and not subscription_added @@ -2598,7 +2602,7 @@ def query_map( block_hash = self._get_current_block_hash(block_hash, reuse_block_hash) if block_hash: self.last_block_hash = block_hash - self.init_runtime(block_hash=block_hash) + runtime = self.init_runtime(block_hash=block_hash) metadata_pallet = self.runtime.metadata.get_metadata_pallet(module) if not metadata_pallet: @@ -2654,19 +2658,6 @@ def query_map( result = [] last_key = None - def concat_hash_len(key_hasher: str) -> int: - """ - Helper function to avoid if statements - """ - if key_hasher == "Blake2_128Concat": - return 16 - elif key_hasher == "Twox64Concat": - return 8 - elif key_hasher == "Identity": - return 0 - else: - raise ValueError("Unsupported hash type") - if len(result_keys) > 0: last_key = result_keys[-1] @@ -2679,49 +2670,16 @@ def concat_hash_len(key_hasher: str) -> int: raise SubstrateRequestException(response["error"]["message"]) for result_group in response["result"]: - for item in result_group["changes"]: - try: - # Determine type string - key_type_string = [] - for n in range(len(params), len(param_types)): - key_type_string.append( - f"[u8; {concat_hash_len(key_hashers[n])}]" - ) - key_type_string.append(param_types[n]) - - item_key_obj = self.decode_scale( - type_string=f"({', '.join(key_type_string)})", - scale_bytes=bytes.fromhex(item[0][len(prefix) :]), - return_scale_obj=True, - ) - - # strip key_hashers to use as item key - if len(param_types) - len(params) == 1: - item_key = item_key_obj[1] - else: - item_key = tuple( - item_key_obj[key + 1] - for key in range(len(params), len(param_types) + 1, 2) - ) - - except Exception as _: - if not ignore_decoding_errors: - raise - item_key = None - - try: - item_bytes = hex_to_bytes_(item[1]) - - item_value = self.decode_scale( - type_string=value_type, - scale_bytes=item_bytes, - return_scale_obj=True, - ) - except Exception as _: - if not ignore_decoding_errors: - raise - item_value = None - result.append([item_key, item_value]) + result = decode_query_map( + result_group["changes"], + prefix, + runtime, + param_types, + params, + value_type, + key_hashers, + ignore_decoding_errors, + ) return QueryMapResult( records=result, page_size=page_size, diff --git a/async_substrate_interface/utils/decoding.py b/async_substrate_interface/utils/decoding.py index f0ce439..6dc7f21 100644 --- a/async_substrate_interface/utils/decoding.py +++ b/async_substrate_interface/utils/decoding.py @@ -1,6 +1,14 @@ -from typing import Union +from typing import Union, TYPE_CHECKING -from bt_decode import AxonInfo, PrometheusInfo +from bt_decode import AxonInfo, PrometheusInfo, decode_list +from scalecodec import ss58_encode +from bittensor_wallet.utils import SS58_FORMAT + +from async_substrate_interface.utils import hex_to_bytes +from async_substrate_interface.types import ScaleObj + +if TYPE_CHECKING: + from async_substrate_interface.types import Runtime def _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value) -> bool: @@ -44,3 +52,85 @@ def _bt_decode_to_dict_or_list(obj) -> Union[dict, list[dict]]: else: as_dict[key] = val return as_dict + + +def _decode_scale_list_with_runtime( + type_strings: list[str], + scale_bytes_list: list[bytes], + runtime_registry, + return_scale_obj: bool = False, +): + obj = decode_list(type_strings, runtime_registry, scale_bytes_list) + if return_scale_obj: + return [ScaleObj(x) for x in obj] + else: + return obj + + +def decode_query_map( + result_group_changes, + prefix, + runtime: "Runtime", + param_types, + params, + value_type, + key_hashers, + ignore_decoding_errors, +): + def concat_hash_len(key_hasher: str) -> int: + """ + Helper function to avoid if statements + """ + if key_hasher == "Blake2_128Concat": + return 16 + elif key_hasher == "Twox64Concat": + return 8 + elif key_hasher == "Identity": + return 0 + else: + raise ValueError("Unsupported hash type") + + hex_to_bytes_ = hex_to_bytes + + result = [] + # Determine type string + key_type_string_ = [] + for n in range(len(params), len(param_types)): + key_type_string_.append(f"[u8; {concat_hash_len(key_hashers[n])}]") + key_type_string_.append(param_types[n]) + key_type_string = f"({', '.join(key_type_string_)})" + + pre_decoded_keys = [] + pre_decoded_key_types = [key_type_string] * len(result_group_changes) + pre_decoded_values = [] + pre_decoded_value_types = [value_type] * len(result_group_changes) + + for item in result_group_changes: + pre_decoded_keys.append(bytes.fromhex(item[0][len(prefix) :])) + pre_decoded_values.append(hex_to_bytes_(item[1])) + all_decoded = _decode_scale_list_with_runtime( + pre_decoded_key_types + pre_decoded_value_types, + pre_decoded_keys + pre_decoded_values, + runtime.registry, + ) + middl_index = len(all_decoded) // 2 + decoded_keys = all_decoded[:middl_index] + decoded_values = [ScaleObj(x) for x in all_decoded[middl_index:]] + for dk, dv in zip(decoded_keys, decoded_values): + try: + # strip key_hashers to use as item key + if len(param_types) - len(params) == 1: + item_key = dk[1] + else: + item_key = tuple( + dk[key + 1] for key in range(len(params), len(param_types) + 1, 2) + ) + + except Exception as _: + if not ignore_decoding_errors: + raise + item_key = None + + item_value = dv + result.append([item_key, item_value]) + return result diff --git a/pyproject.toml b/pyproject.toml index f80eb46..71f0fd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.0.9" +version = "1.1.0" description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } @@ -10,7 +10,7 @@ dependencies = [ "wheel", "asyncstdlib~=3.13.0", "bittensor-wallet>=2.1.3", - "bt-decode==v0.5.0", + "bt-decode==v0.6.0", "scalecodec~=1.2.11", "websockets>=14.1", "xxhash" diff --git a/tests/test_old_new.py b/tests/test_old_new.py index 897c589..c0a85e4 100644 --- a/tests/test_old_new.py +++ b/tests/test_old_new.py @@ -1,13 +1,22 @@ +import asyncio import os +import time + import bittensor as bt +from bittensor.core.chain_data import decode_account_id +from bittensor.core.settings import SS58_FORMAT import pytest +import substrateinterface + +from async_substrate_interface.async_substrate import AsyncSubstrateInterface +from async_substrate_interface.sync_substrate import SubstrateInterface try: n = int(os.getenv("NUMBER_RUNS")) except TypeError: n = 3 - +FINNEY_ENTRYPOINT = "wss://entrypoint-finney.opentensor.ai:443" coldkey = "5HHHHHzgLnYRvnKkHd45cRUDMHXTSwx7MjUzxBrKbY4JfZWn" # dtao epoch is 4920350 @@ -49,3 +58,93 @@ def test_sync(): for i in range(n): s2 = st.get_stake_for_coldkey(coldkey, block=b_post + i) print(f"at block {b_post + i}: {s2}") + + +@pytest.mark.asyncio +async def test_query_map(): + async def async_gathering(): + async def exhaust(qmr): + r = [] + async for k, v in await qmr: + r.append((k, v)) + return r + + start = time.time() + async with AsyncSubstrateInterface( + FINNEY_ENTRYPOINT, ss58_format=SS58_FORMAT + ) as substrate: + block_hash = await substrate.get_chain_head() + tasks = [ + substrate.query_map( + "SubtensorModule", + "TaoDividendsPerSubnet", + [netuid], + block_hash=block_hash, + ) + for netuid in range(1, 51) + ] + tasks = [exhaust(task) for task in tasks] + print(time.time() - start) + results_dicts_list = [] + for future in asyncio.as_completed(tasks): + result = await future + results_dicts_list.extend( + [(decode_account_id(k), v.value) for k, v in result] + ) + + elapsed = time.time() - start + print(f"Async Time: {elapsed}") + + print("Async Results", len(results_dicts_list)) + return results_dicts_list, block_hash + + def sync_new_method(block_hash): + result_dicts_list = [] + start = time.time() + with SubstrateInterface( + FINNEY_ENTRYPOINT, ss58_format=SS58_FORMAT + ) as substrate: + for netuid in range(1, 51): + tao_divs = list( + substrate.query_map( + "SubtensorModule", + "TaoDividendsPerSubnet", + [netuid], + block_hash=block_hash, + ) + ) + tao_divs = [(decode_account_id(k), v.value) for k, v in tao_divs] + result_dicts_list.extend(tao_divs) + print("New Sync Time:", time.time() - start) + print("New Sync Results", len(result_dicts_list)) + return result_dicts_list + + def sync_old_method(block_hash): + results_dicts_list = [] + start = time.time() + substrate = substrateinterface.SubstrateInterface( + FINNEY_ENTRYPOINT, ss58_format=SS58_FORMAT + ) + for netuid in range(1, 51): + tao_divs = list( + substrate.query_map( + "SubtensorModule", + "TaoDividendsPerSubnet", + [netuid], + block_hash=block_hash, + ) + ) + tao_divs = [(k.value, v.value) for k, v in tao_divs] + results_dicts_list.extend(tao_divs) + substrate.close() + print("Legacy Sync Time:", time.time() - start) + print("Legacy Sync Results", len(results_dicts_list)) + return results_dicts_list + + async_, block_hash_ = await async_gathering() + new_sync_ = sync_new_method(block_hash_) + legacy_sync = sync_old_method(block_hash_) + for k_v in async_: + assert k_v in legacy_sync + for k_v in new_sync_: + assert k_v in legacy_sync