diff --git a/CHANGELOG.md b/CHANGELOG.md index 138c8a6..22379d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## 1.4.0 /2025-07-07 +* Removes unused imports by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/139 +* Improve CachedFetcher by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/140 +* Only use Runtime objects in AsyncSubstrateInterface by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/141 +* python ss58 conversion by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/143 +* fully exhaust query map by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/144 +* Only use v14 decoding for events by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/145 + + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.3.1...v1.4.0 + ## 1.3.1 /2025-06-11 * Fixes default vals for archive_nodes by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/134 * Adds ability to log raw websockets for debugging. by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/133 diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 598a882..fb86216 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -23,7 +23,7 @@ ) from bt_decode import MetadataV15, PortableRegistry, decode as decode_by_type_string -from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject +from scalecodec.base import ScaleBytes, ScaleType from scalecodec.types import ( GenericCall, GenericExtrinsic, @@ -34,13 +34,11 @@ from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed, WebSocketException -from async_substrate_interface.const import SS58_FORMAT from async_substrate_interface.errors import ( SubstrateRequestException, ExtrinsicNotFound, BlockNotFound, MaxRetriesExceeded, - MetadataAtVersionNotFound, StateDiscardedError, ) from async_substrate_interface.protocols import Keypair @@ -58,10 +56,15 @@ get_next_id, rng as random, ) -from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher +from async_substrate_interface.utils.cache import ( + async_sql_lru_cache, + cached_fetcher, +) from async_substrate_interface.utils.decoding import ( _determine_if_old_runtime_call, _bt_decode_to_dict_or_list, + legacy_scale_decode, + convert_account_ids, ) from async_substrate_interface.utils.storage import StorageKey from async_substrate_interface.type_registry import _TYPE_REGISTRY @@ -278,17 +281,26 @@ async def process_events(self): self.__weight = dispatch_info["weight"] if "Module" in dispatch_error: - module_index = dispatch_error["Module"][0]["index"] - error_index = int.from_bytes( - bytes(dispatch_error["Module"][0]["error"]), - byteorder="little", - signed=False, - ) + if isinstance(dispatch_error["Module"], tuple): + module_index = dispatch_error["Module"][0] + error_index = dispatch_error["Module"][1] + else: + module_index = dispatch_error["Module"]["index"] + error_index = dispatch_error["Module"]["error"] if isinstance(error_index, str): # Actual error index is first u8 in new [u8; 4] format error_index = int(error_index[2:4], 16) - module_error = self.substrate.metadata.get_module_error( + + if self.block_hash: + runtime = await self.substrate.init_runtime( + 
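Side note on the `dispatch_error["Module"]` branching introduced here: depending on the decoder, a module error arrives either as a `(module_index, error_index)` tuple or as an `{"index": ..., "error": ...}` dict. A self-contained sketch of the same normalization (the helper name is hypothetical, not part of the patch):

```python
def extract_module_error_indices(dispatch_error: dict) -> tuple[int, int]:
    """Hypothetical helper mirroring the branching in process_events."""
    module = dispatch_error["Module"]
    if isinstance(module, tuple):
        # Newer decoder output: (module_index, error_index)
        module_index, error_index = module[0], module[1]
    else:
        # Legacy output: {"index": ..., "error": ...}
        module_index, error_index = module["index"], module["error"]
    if isinstance(error_index, str):
        # Actual error index is the first u8 of the new [u8; 4] hex format
        error_index = int(error_index[2:4], 16)
    return module_index, error_index


assert extract_module_error_indices({"Module": (4, 2)}) == (4, 2)
assert extract_module_error_indices({"Module": {"index": 4, "error": "0x06000000"}}) == (4, 6)
```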
block_hash=self.block_hash
+                )
+            else:
+                runtime = await self.substrate.init_runtime(
+                    block_id=self.block_number
+                )
+            module_error = runtime.metadata.get_module_error(
                 module_index=module_index, error_index=error_index
             )
             self.__error_message = {
@@ -453,7 +465,6 @@ async def retrieve_next_page(self, start_key) -> list:
         )
         if len(result.records) < self.page_size:
             self.loading_complete = True
-            # Update last key from new result set to use as offset for next page
             self.last_key = result.last_key
         return result.records
@@ -731,6 +742,7 @@ def __init__(
         _mock: bool = False,
         _log_raw_websockets: bool = False,
         ws_shutdown_timer: float = 5.0,
+        decode_ss58: bool = False,
     ):
         """
         The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
@@ -750,8 +762,16 @@ def __init__(
             _mock: whether to use mock version of the subtensor interface
             _log_raw_websockets: whether to log raw websocket requests during RPC requests
             ws_shutdown_timer: how long after the last connection your websocket should close
+            decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples.
         """
+        super().__init__(
+            type_registry,
+            type_registry_preset,
+            use_remote_preset,
+            ss58_format,
+            decode_ss58,
+        )
         self.max_retries = max_retries
         self.retry_timeout = retry_timeout
         self.chain_endpoint = url
@@ -780,26 +800,13 @@ def __init__(
         }
         self.initialized = False
         self._forgettable_task = None
-        self.ss58_format = ss58_format
         self.type_registry = type_registry
         self.type_registry_preset = type_registry_preset
         self.runtime_cache = RuntimeCache()
-        self.runtime_config = RuntimeConfigurationObject(
-            ss58_format=self.ss58_format, implements_scale_info=True
-        )
         self._nonces = {}
         self.metadata_version_hex = "0x0f000000"  # v15
-        self.reload_type_registry()
         self._initializing = False
-        self.registry_type_map = {}
-        self.type_id_to_name = {}
         self._mock = _mock
-        self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash)
-        self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash)
-        self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info)
-        self._runtime_version_for_fetcher = CachedFetcher(
-            512, self._get_block_runtime_version_for
-        )

     async def __aenter__(self):
         if not self._mock:
@@ -815,13 +822,52 @@ async def initialize(self):
         if not self._chain:
             chain = await self.rpc_request("system_chain", [])
             self._chain = chain.get("result")
-        await self.init_runtime()
+        runtime = await self.init_runtime()
+        if self.ss58_format is None:
+            # Check and apply runtime constants
+            ss58_prefix_constant = await self.get_constant(
+                "System", "SS58Prefix", runtime=runtime
+            )
+
+            if ss58_prefix_constant:
+                self.ss58_format = ss58_prefix_constant.value
+                runtime.ss58_format = ss58_prefix_constant.value
+                runtime.runtime_config.ss58_format = ss58_prefix_constant.value
         self.initialized = True
         self._initializing = False

     async def __aexit__(self, exc_type, exc_val, exc_tb):
         pass

+    @property
+    def metadata(self):
+        warnings.warn(
+            "Calling AsyncSubstrateInterface.metadata is deprecated, as metadata is runtime-dependent, and it "
+            "can be unclear for which runtime you seek the metadata. You should instead use the specific runtime's "
+            "metadata. For now, the most recently used runtime will be given.",
+            category=DeprecationWarning,
+        )
+        runtime = self.runtime_cache.last_used
+        if not runtime or runtime.metadata is None:
+            raise AttributeError(
+                "Metadata not found. 
This generally indicates that the AsyncSubstrateInterface object " + "is not properly async initialized." + ) + else: + return runtime.metadata + + @property + def implements_scaleinfo(self) -> Optional[bool]: + """ + Returns True if most-recently-used runtime implements a `PortableRegistry` (`MetadataV14` and higher). Returns + `None` if no runtime has been loaded. + """ + runtime = self.runtime_cache.last_used + if runtime is not None: + return runtime.implements_scaleinfo + else: + return None + @property async def properties(self): if self._properties is None: @@ -860,8 +906,8 @@ async def name(self): async def get_storage_item( self, module: str, storage_function: str, block_hash: str = None ): - await self.init_runtime(block_hash=block_hash) - metadata_pallet = self.runtime.metadata.get_metadata_pallet(module) + runtime = await self.init_runtime(block_hash=block_hash) + metadata_pallet = runtime.metadata.get_metadata_pallet(module) storage_item = metadata_pallet.get_storage_function(storage_function) return storage_item @@ -878,7 +924,7 @@ async def _get_current_block_hash( async def _load_registry_at_block( self, block_hash: Optional[str] - ) -> tuple[MetadataV15, PortableRegistry]: + ) -> tuple[Optional[MetadataV15], Optional[PortableRegistry]]: # Should be called for any block that fails decoding. # Possibly the metadata was different. try: @@ -892,59 +938,38 @@ async def _load_registry_at_block( "Client error: Execution failed: Other: Exported method Metadata_metadata_at_version is not found" in e.args ): - raise MetadataAtVersionNotFound + return None, None else: raise e metadata_option_hex_str = metadata_rpc_result["result"] metadata_option_bytes = bytes.fromhex(metadata_option_hex_str[2:]) metadata = MetadataV15.decode_from_metadata_option(metadata_option_bytes) registry = PortableRegistry.from_metadata_v15(metadata) - self._load_registry_type_map(registry) return metadata, registry - async def _wait_for_registry(self, _attempt: int = 1, _retries: int = 3) -> None: - async def _waiter(): - while self.runtime.registry is None: - await asyncio.sleep(0.1) - return - - try: - if not self.runtime.registry: - await asyncio.wait_for(_waiter(), timeout=10) - except TimeoutError: - # indicates that registry was never loaded - if not self._initializing: - raise AttributeError( - "Registry was never loaded. This did not occur during initialization, which usually indicates " - "you must first initialize the AsyncSubstrateInterface object, either with " - "`await AsyncSubstrateInterface.initialize()` or running with `async with`" - ) - elif _attempt < _retries: - await self._load_registry_at_block(None) - return await self._wait_for_registry(_attempt + 1, _retries) - else: - raise AttributeError( - "Registry was never loaded. This occurred during initialization, which usually indicates a " - "connection or node error." - ) - async def encode_scale( - self, type_string, value: Any, _attempt: int = 1, _retries: int = 3 + self, + type_string, + value: Any, + block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, ) -> bytes: """ - Helper function to encode arbitrary data into SCALE-bytes for given RUST type_string + Helper function to encode arbitrary data into SCALE-bytes for given RUST type_string. If neither `block_hash` + nor `runtime` are supplied, the runtime of the current block will be used. 
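A plausible usage sketch of the reworked `encode_scale`/`decode_scale`: pin a `Runtime` once with `init_runtime()` and pass it explicitly, instead of letting every call re-resolve the runtime from a block hash (the endpoint URL is illustrative):

```python
import asyncio

from async_substrate_interface.async_substrate import AsyncSubstrateInterface


async def main():
    # Endpoint is illustrative; any Substrate node works.
    async with AsyncSubstrateInterface("wss://entrypoint-finney.opentensor.ai:443") as substrate:
        runtime = await substrate.init_runtime()  # pin one runtime, reuse it
        encoded = await substrate.encode_scale("u16", 42, runtime=runtime)
        decoded = await substrate.decode_scale("u16", encoded, runtime=runtime)
        assert decoded == 42


asyncio.run(main())
```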
Args: type_string: the type string of the SCALE object for decoding value: value to encode - _attempt: the current number of attempts to load the registry needed to encode the value - _retries: the maximum number of attempts to load the registry needed to encode the value + block_hash: hash of the block where the desired runtime is located. Ignored if supplying `runtime` + runtime: the runtime to use for the scale encoding. If supplied, `block_hash` is ignored Returns: encoded bytes """ - await self._wait_for_registry(_attempt, _retries) - return self._encode_scale(type_string, value) + if runtime is None: + runtime = await self.init_runtime(block_hash=block_hash) + return self._encode_scale(type_string, value, runtime=runtime) async def decode_scale( self, @@ -953,6 +978,9 @@ async def decode_scale( _attempt=1, _retries=3, return_scale_obj: bool = False, + block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, + force_legacy: bool = False, ) -> Union[ScaleObj, Any]: """ Helper function to decode arbitrary SCALE-bytes (e.g. 0x02000000) according to given RUST type_string @@ -965,6 +993,10 @@ async def decode_scale( _attempt: the number of attempts to pull the registry before timing out _retries: the number of retries to pull the registry before timing out return_scale_obj: Whether to return the decoded value wrapped in a SCALE-object-like wrapper, or raw. + block_hash: Hash of the block where the desired runtime is located. Ignored if supplying `runtime` + runtime: Optional Runtime object whose registry to use for decoding. If not specified, runtime will be + loaded based on the block hash specified (or latest block if no block_hash is specified) + force_legacy: Whether to explicitly use legacy Python-only decoding (non bt-decode). Returns: Decoded object @@ -973,36 +1005,28 @@ async def decode_scale( return None if type_string == "scale_info::0": # Is an AccountId # Decode AccountId bytes to SS58 address - return ss58_encode(scale_bytes, SS58_FORMAT) + return ss58_encode(scale_bytes, self.ss58_format) else: - await self._wait_for_registry(_attempt, _retries) - obj = decode_by_type_string(type_string, self.runtime.registry, scale_bytes) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) + if runtime.metadata_v15 is not None and force_legacy is False: + obj = decode_by_type_string(type_string, runtime.registry, scale_bytes) + if self.decode_ss58: + try: + type_str_int = int(type_string.split("::")[1]) + decoded_type_str = runtime.type_id_to_name[type_str_int] + obj = convert_account_ids( + obj, decoded_type_str, runtime.ss58_format + ) + except (ValueError, KeyError): + pass + else: + obj = legacy_scale_decode(type_string, scale_bytes, runtime) if return_scale_obj: return ScaleObj(obj) else: return obj - def load_runtime(self, runtime): - self.runtime = runtime - - # Update type registry - self.reload_type_registry(use_remote_preset=False, auto_discover=True) - - self.runtime_config.set_active_spec_version_id(runtime.runtime_version) - if self.implements_scaleinfo: - logger.debug("Add PortableRegistry from metadata to type registry") - self.runtime_config.add_portable_registry(runtime.metadata) - # Set runtime compatibility flags - try: - _ = self.runtime_config.create_scale_object("sp_weights::weight_v2::Weight") - self.config["is_weight_v2"] = True - self.runtime_config.update_type_registry_types( - {"Weight": "sp_weights::weight_v2::Weight"} - ) - except NotImplementedError: - self.config["is_weight_v2"] = False - 
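Conceptually, the `decode_ss58` path post-processes what bt-decode returns: raw 32-byte AccountIds become SS58 strings. A simplified stand-in for `convert_account_ids` (whose real implementation in `utils.decoding` also walks nested containers; this flat version is only illustrative):

```python
from scalecodec.utils.ss58 import ss58_encode


def maybe_to_ss58(value, ss58_format: int = 42):
    """Return an SS58 string if `value` looks like a raw 32-byte AccountId."""
    if isinstance(value, (bytes, bytearray)) and len(value) == 32:
        return ss58_encode(bytes(value), ss58_format)
    if (
        isinstance(value, tuple)
        and len(value) == 32
        and all(isinstance(b, int) for b in value)
    ):
        return ss58_encode(bytes(value), ss58_format)
    return value
```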
self.runtime_config.update_type_registry_types({"Weight": "WeightV1"}) - async def init_runtime( self, block_hash: Optional[str] = None, block_id: Optional[int] = None ) -> Runtime: @@ -1026,10 +1050,16 @@ async def init_runtime( raise ValueError("Cannot provide block_hash and block_id at the same time") if block_id is not None: + if runtime := self.runtime_cache.retrieve(block=block_id): + return runtime block_hash = await self.get_block_hash(block_id) if not block_hash: block_hash = await self.get_chain_head() + else: + self.last_block_hash = block_hash + if runtime := self.runtime_cache.retrieve(block_hash=block_hash): + return runtime runtime_version = await self.get_block_runtime_version_for(block_hash) if runtime_version is None: @@ -1037,53 +1067,78 @@ async def init_runtime( f"No runtime information for block '{block_hash}'" ) - if self.runtime and runtime_version == self.runtime.runtime_version: - return self.runtime - - runtime = self.runtime_cache.retrieve(runtime_version=runtime_version) - if not runtime: - self.last_block_hash = block_hash + if runtime := self.runtime_cache.retrieve(runtime_version=runtime_version): + return runtime + else: + return await self.get_runtime_for_version(runtime_version, block_hash) - runtime_block_hash = await self.get_parent_block_hash(block_hash) + @cached_fetcher(max_size=16, cache_key_index=0) + async def get_runtime_for_version( + self, runtime_version: int, block_hash: Optional[str] = None + ) -> Runtime: + """ + Retrieves the `Runtime` for a given runtime version at a given block hash. + Args: + runtime_version: version of the runtime (from `get_block_runtime_version_for`) + block_hash: hash of the block to query - runtime_info = await self.get_block_runtime_info(runtime_block_hash) + Returns: + Runtime object for the given runtime version + """ + return await self._get_runtime_for_version(runtime_version, block_hash) - metadata, (metadata_v15, registry) = await asyncio.gather( - self.get_block_metadata(block_hash=runtime_block_hash, decode=True), - self._load_registry_at_block(block_hash=runtime_block_hash), + async def _get_runtime_for_version( + self, runtime_version: int, block_hash: Optional[str] = None + ) -> Runtime: + if not block_hash: + block_hash, runtime_block_hash, block_number = await asyncio.gather( + self.get_chain_head(), + self.get_parent_block_hash(block_hash), + self.get_block_number(block_hash), ) - if metadata is None: - # does this ever happen? - raise SubstrateRequestException( - f"No metadata for block '{runtime_block_hash}'" - ) - logger.debug( - f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node" + else: + runtime_block_hash, block_number = await asyncio.gather( + self.get_parent_block_hash(block_hash), + self.get_block_number(block_hash), ) - - runtime = Runtime( - chain=self.chain, - runtime_config=self.runtime_config, - metadata=metadata, - type_registry=self.type_registry, - metadata_v15=metadata_v15, - runtime_info=runtime_info, - registry=registry, + runtime_info, metadata, (metadata_v15, registry) = await asyncio.gather( + self.get_block_runtime_info(runtime_block_hash), + self.get_block_metadata(block_hash=runtime_block_hash, decode=True), + self._load_registry_at_block(block_hash=runtime_block_hash), + ) + if metadata is None: + # does this ever happen? 
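The `cached_fetcher` decorator used here comes from `async_substrate_interface.utils.cache` (reworked in PR #140) and its implementation is not part of this diff. A rough sketch of the shape such a decorator takes — assuming a plain async LRU keyed on one positional argument, without any request coalescing the real version may do:

```python
from collections import OrderedDict
from functools import wraps


def cached_fetcher(max_size: int, cache_key_index: int = 0):
    def decorator(fn):
        cache: OrderedDict = OrderedDict()

        @wraps(fn)
        async def wrapper(self, *args):
            key = args[cache_key_index]
            if key in cache:
                cache.move_to_end(key)  # LRU bump on hit
                return cache[key]
            result = await fn(self, *args)
            cache[key] = result
            if len(cache) > max_size:
                cache.popitem(last=False)  # evict least-recently-used entry
            return result

        return wrapper

    return decorator
```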
+ raise SubstrateRequestException( + f"No metadata for block '{runtime_block_hash}'" ) - self.runtime_cache.add_item( - runtime_version=runtime_version, runtime=runtime + if metadata_v15 is not None: + logger.debug( + f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node" ) - - self.load_runtime(runtime) - - if self.ss58_format is None: - # Check and apply runtime constants - ss58_prefix_constant = await self.get_constant( - "System", "SS58Prefix", block_hash=block_hash + else: + logger.debug( + f"Exported method Metadata_metadata_at_version is not found for {runtime_version}. This indicates the " + f"block is quite old, decoding for this block will use legacy Python decoding." ) - - if ss58_prefix_constant: - self.ss58_format = ss58_prefix_constant + implements_scale_info = metadata.portable_registry is not None + runtime = Runtime( + chain=self.chain, + runtime_config=self._runtime_config_copy( + implements_scale_info=implements_scale_info + ), + metadata=metadata, + type_registry=self.type_registry, + metadata_v15=metadata_v15, + runtime_info=runtime_info, + registry=registry, + ss58_format=self.ss58_format, + ) + self.runtime_cache.add_item( + block=block_number, + block_hash=block_hash, + runtime_version=runtime_version, + runtime=runtime, + ) return runtime async def create_storage_key( @@ -1105,14 +1160,14 @@ async def create_storage_key( Returns: StorageKey """ - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) return StorageKey.create_from_storage_function( pallet, storage_function, params, - runtime_config=self.runtime_config, - metadata=self.runtime.metadata, + runtime_config=runtime.runtime_config, + metadata=runtime.metadata, ) async def subscribe_storage( @@ -1144,7 +1199,7 @@ async def subscription_handler(storage_key, obj, subscription_id): subscription_handler: coroutine function to handle value changes of subscription """ - await self.init_runtime() + runtime = await self.init_runtime() storage_key_map = {s.to_hex(): s for s in storage_keys} @@ -1188,6 +1243,7 @@ async def result_handler( updated_obj = await self.decode_scale( type_string=change_scale_type, scale_bytes=hex_to_bytes(change_data), + runtime=runtime, ) subscription_result = await subscription_handler( @@ -1239,36 +1295,45 @@ async def retrieve_pending_extrinsics(self) -> list: return extrinsics - async def get_metadata_storage_functions(self, block_hash=None) -> list: + async def get_metadata_storage_functions( + self, block_hash=None, runtime: Optional[Runtime] = None + ) -> list: """ Retrieves a list of all storage functions in metadata active at given block_hash (or chaintip if block_hash is omitted) Args: block_hash: hash of the blockchain block whose runtime to use + runtime: Optional `Runtime` whose metadata to use Returns: list of storage functions """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) storage_list = [] - for module_idx, module in enumerate(self.metadata.pallets): + for module_idx, module in enumerate(runtime.metadata.pallets): if module.storage: for storage in module.storage: storage_list.append( self.serialize_storage_item( storage_item=storage, module=module, - spec_version_id=self.runtime.runtime_version, + spec_version_id=runtime.runtime_version, + runtime=runtime, ) ) return storage_list async def get_metadata_storage_function( - self, module_name, storage_name, block_hash=None + self, + module_name, + storage_name, + 
block_hash=None, + runtime: Optional[Runtime] = None, ): """ Retrieves the details of a storage function for given module name, call function name and block_hash @@ -1277,47 +1342,57 @@ async def get_metadata_storage_function( module_name storage_name block_hash + runtime: Optional `Runtime` whose metadata to use Returns: Metadata storage function """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) - pallet = self.metadata.get_metadata_pallet(module_name) + pallet = runtime.metadata.get_metadata_pallet(module_name) if pallet: return pallet.get_storage_function(storage_name) async def get_metadata_errors( - self, block_hash=None + self, block_hash=None, runtime: Optional[Runtime] = None ) -> list[dict[str, Optional[str]]]: """ Retrieves a list of all errors in metadata active at given block_hash (or chaintip if block_hash is omitted) Args: block_hash: hash of the blockchain block whose metadata to use + runtime: Optional `Runtime` whose metadata to use Returns: list of errors in the metadata """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) error_list = [] - for module_idx, module in enumerate(self.runtime.metadata.pallets): + for module_idx, module in enumerate(runtime.metadata.pallets): if module.errors: for error in module.errors: error_list.append( self.serialize_module_error( module=module, error=error, - spec_version=self.runtime.runtime_version, + spec_version=runtime.runtime_version, ) ) return error_list - async def get_metadata_error(self, module_name, error_name, block_hash=None): + async def get_metadata_error( + self, + module_name: str, + error_name: str, + block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, + ): """ Retrieves the details of an error for given module name, call function name and block_hash @@ -1325,21 +1400,23 @@ async def get_metadata_error(self, module_name, error_name, block_hash=None): module_name: module name for the error lookup error_name: error name for the error lookup block_hash: hash of the blockchain block whose metadata to use + runtime: Optional `Runtime` whose metadata to use Returns: error """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) - for module_idx, module in enumerate(self.runtime.metadata.pallets): + for module_idx, module in enumerate(runtime.metadata.pallets): if module.name == module_name and module.errors: for error in module.errors: if error_name == error.name: return error async def get_metadata_runtime_call_functions( - self, block_hash: str = None + self, block_hash: str = None, runtime: Optional[Runtime] = None ) -> list[GenericRuntimeCallDefinition]: """ Get a list of available runtime API calls @@ -1347,83 +1424,61 @@ async def get_metadata_runtime_call_functions( Returns: list of runtime call functions """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) call_functions = [] - for api, methods in self.runtime_config.type_registry["runtime_api"].items(): + for api, methods in runtime.runtime_config.type_registry["runtime_api"].items(): for method in methods["methods"].keys(): call_functions.append( - await self.get_metadata_runtime_call_function(api, method) + await self.get_metadata_runtime_call_function( + api, method, runtime=runtime + ) ) return call_functions async def 
get_metadata_runtime_call_function( - self, api: str, method: str, block_hash: str = None - ) -> GenericRuntimeCallDefinition: - """ - Get details of a runtime API call - - Args: - api: Name of the runtime API e.g. 'TransactionPaymentApi' - method: Name of the method e.g. 'query_fee_details' - - Returns: - runtime call function - """ - await self.init_runtime(block_hash=block_hash) - - try: - runtime_call_def = self.runtime_config.type_registry["runtime_api"][api][ - "methods" - ][method] - runtime_call_def["api"] = api - runtime_call_def["method"] = method - runtime_api_types = self.runtime_config.type_registry["runtime_api"][ - api - ].get("types", {}) - except KeyError: - raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") - - # Add runtime API types to registry - self.runtime_config.update_type_registry_types(runtime_api_types) - - runtime_call_def_obj = await self.create_scale_object("RuntimeCallDefinition") - runtime_call_def_obj.encode(runtime_call_def) - - return runtime_call_def_obj - - async def get_metadata_runtime_call_function( - self, api: str, method: str + self, + api: str, + method: str, + block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, ) -> GenericRuntimeCallDefinition: """ - Get details of a runtime API call + Get details of a runtime API call. If not supplying `block_hash` or `runtime`, the runtime of the current block + will be used. Args: api: Name of the runtime API e.g. 'TransactionPaymentApi' method: Name of the method e.g. 'query_fee_details' + block_hash: Hash of the block whose runtime to use, if not specifying `runtime` + runtime: The `Runtime` object whose metadata to use. Returns: GenericRuntimeCallDefinition """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) try: - runtime_call_def = self.runtime_config.type_registry["runtime_api"][api][ + runtime_call_def = runtime.runtime_config.type_registry["runtime_api"][api][ "methods" ][method] runtime_call_def["api"] = api runtime_call_def["method"] = method - runtime_api_types = self.runtime_config.type_registry["runtime_api"][ + runtime_api_types = runtime.runtime_config.type_registry["runtime_api"][ api ].get("types", {}) except KeyError: raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") # Add runtime API types to registry - self.runtime_config.update_type_registry_types(runtime_api_types) + runtime.runtime_config.update_type_registry_types(runtime_api_types) - runtime_call_def_obj = await self.create_scale_object("RuntimeCallDefinition") + runtime_call_def_obj = await self.create_scale_object( + "RuntimeCallDefinition", runtime=runtime + ) runtime_call_def_obj.encode(runtime_call_def) return runtime_call_def_obj @@ -1438,7 +1493,7 @@ async def _get_block_handler( subscription_handler: Optional[Callable[[dict], Awaitable[Any]]] = None, ): try: - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) except BlockNotFound: return None @@ -1453,15 +1508,15 @@ async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: block_data["header"]["number"], 16 ) - extrinsic_cls = self.runtime_config.get_decoder_class("Extrinsic") + extrinsic_cls = runtime.runtime_config.get_decoder_class("Extrinsic") if "extrinsics" in block_data: for idx, extrinsic_data in enumerate(block_data["extrinsics"]): try: extrinsic_decoder = extrinsic_cls( data=ScaleBytes(extrinsic_data), - metadata=self.runtime.metadata, - 
runtime_config=self.runtime_config, + metadata=runtime.metadata, + runtime_config=runtime.runtime_config, ) extrinsic_decoder.decode(check_remaining=True) block_data["extrinsics"][idx] = extrinsic_decoder @@ -1475,7 +1530,7 @@ async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: if isinstance(log_data, str): # Convert digest log from hex (backwards compatibility) try: - log_digest_cls = self.runtime_config.get_decoder_class( + log_digest_cls = runtime.runtime_config.get_decoder_class( "sp_runtime::generic::digest::DigestItem" ) @@ -1492,17 +1547,20 @@ async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: block_data["header"]["digest"]["logs"][idx] = log_digest if include_author and "PreRuntime" in log_digest.value: - if self.implements_scaleinfo: + if runtime.implements_scaleinfo: engine = bytes(log_digest[1][0]) # Retrieve validator set parent_hash = block_data["header"]["parentHash"] validator_set = await self.query( - "Session", "Validators", block_hash=parent_hash + "Session", + "Validators", + block_hash=parent_hash, + runtime=runtime, ) if engine == b"BABE": babe_predigest = ( - self.runtime_config.create_scale_object( + runtime.runtime_config.create_scale_object( type_string="RawBabePreDigest", data=ScaleBytes( bytes(log_digest[1][1]) @@ -1525,7 +1583,7 @@ async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: elif engine == b"aura": aura_predigest = ( - self.runtime_config.create_scale_object( + runtime.runtime_config.create_scale_object( type_string="RawAuraPreDigest", data=ScaleBytes( bytes(log_digest[1][1]) @@ -1554,6 +1612,7 @@ async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: "Session", "Validators", block_hash=block_hash, + runtime=runtime, ) rank_validator = log_digest.value["PreRuntime"][ "data" @@ -1612,19 +1671,24 @@ async def result_handler( ) ], result_handler=result_handler, + runtime=runtime, ) return result["_get_block_handler"][-1] else: if header_only: - response = await self.rpc_request("chain_getHeader", [block_hash]) + response = await self.rpc_request( + "chain_getHeader", [block_hash], runtime=runtime + ) return await decode_block( {"header": response["result"]}, block_data_hash=block_hash ) else: - response = await self.rpc_request("chain_getBlock", [block_hash]) + response = await self.rpc_request( + "chain_getBlock", [block_hash], runtime=runtime + ) return await decode_block( response["result"]["block"], block_data_hash=block_hash ) @@ -1873,11 +1937,22 @@ def convert_event_data(data): attributes = attributes_data if isinstance(attributes, dict): for key, value in attributes.items(): - if isinstance(value, dict): + if key == "who": + who = ss58_encode(bytes(value[0]), self.ss58_format) + attributes["who"] = who + elif key == "from": + who_from = ss58_encode(bytes(value[0]), self.ss58_format) + attributes["from"] = who_from + elif key == "to": + who_to = ss58_encode(bytes(value[0]), self.ss58_format) + attributes["to"] = who_to + elif isinstance(value, dict): # Convert nested single-key dictionaries to their keys as strings - sub_key = next(iter(value.keys())) - if value[sub_key] == (): - attributes[key] = sub_key + for sub_key, sub_value in value.items(): + if isinstance(sub_value, dict): + for sub_sub_key, sub_sub_value in sub_value.items(): + if sub_sub_value == (): + attributes[key][sub_key] = sub_sub_key # Create the converted dictionary converted = { @@ -1899,11 +1974,15 @@ def convert_event_data(data): block_hash = await self.get_chain_head() storage_obj = 
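The `who`/`from`/`to` handling added to `convert_event_data` reduces to re-encoding a raw public key as an SS58 address, e.g. (prefix 42 is the generic Substrate prefix, which Bittensor also uses):

```python
from scalecodec.utils.ss58 import ss58_encode

raw = (bytes.fromhex("d4" * 32),)  # the decoder wraps the 32 raw bytes in a 1-tuple
attributes = {"who": raw}
attributes["who"] = ss58_encode(bytes(attributes["who"][0]), 42)
print(attributes["who"])  # an SS58 address, e.g. '5GsG6P...'
```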
await self.query( - module="System", storage_function="Events", block_hash=block_hash + module="System", + storage_function="Events", + block_hash=block_hash, + force_legacy_decode=True, ) + # bt-decode Metadata V15 is not ideal for events. Force legacy decoding for this if storage_obj: for item in list(storage_obj): - events.append(convert_event_data(item)) + events.append(item) return events async def get_metadata(self, block_hash=None) -> MetadataV15: @@ -1921,10 +2000,19 @@ async def get_metadata(self, block_hash=None) -> MetadataV15: return runtime.metadata_v15 - async def get_parent_block_hash(self, block_hash): - return await self._parent_hash_fetcher.execute(block_hash) + @cached_fetcher(max_size=512) + async def get_parent_block_hash(self, block_hash) -> str: + """ + Retrieves the block hash of the parent of the given block hash + Args: + block_hash: hash of the block to query - async def _get_parent_block_hash(self, block_hash): + Returns: + Hash of the parent block hash, or the original block hash (if it has not parent) + """ + return await self._get_parent_block_hash(block_hash) + + async def _get_parent_block_hash(self, block_hash) -> str: block_header = await self.rpc_request("chain_getHeader", [block_hash]) if block_header["result"] is None: @@ -1967,25 +2055,27 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any: "Unknown error occurred during retrieval of events" ) + @cached_fetcher(max_size=16) async def get_block_runtime_info(self, block_hash: str) -> dict: - return await self._runtime_info_fetcher.execute(block_hash) + """ + Retrieve the runtime info of given block_hash + """ + return await self._get_block_runtime_info(block_hash) get_block_runtime_version = get_block_runtime_info async def _get_block_runtime_info(self, block_hash: str) -> dict: - """ - Retrieve the runtime info of given block_hash - """ response = await self.rpc_request("state_getRuntimeVersion", [block_hash]) return response.get("result") + @cached_fetcher(max_size=512) async def get_block_runtime_version_for(self, block_hash: str): - return await self._runtime_version_for_fetcher.execute(block_hash) - - async def _get_block_runtime_version_for(self, block_hash: str): """ Retrieve the runtime version of the parent of a given block_hash """ + return await self._get_block_runtime_version_for(block_hash) + + async def _get_block_runtime_version_for(self, block_hash: str): parent_block_hash = await self.get_parent_block_hash(block_hash) runtime_info = await self.get_block_runtime_info(parent_block_hash) if runtime_info is None: @@ -2036,13 +2126,16 @@ async def _preprocess( storage_function: str, module: str, raw_storage_key: Optional[bytes] = None, + runtime: Optional[Runtime] = None, ) -> Preprocessed: """ Creates a Preprocessed data object for passing to `_make_rpc_request` """ params = query_for if query_for else [] # Search storage call in metadata - metadata_pallet = self.runtime.metadata.get_metadata_pallet(module) + if not runtime: + runtime = self.runtime + metadata_pallet = runtime.metadata.get_metadata_pallet(module) if not metadata_pallet: raise SubstrateRequestException(f'Pallet "{module}" not found') @@ -2069,16 +2162,16 @@ async def _preprocess( pallet=module, storage_function=storage_function, value_scale_type=value_scale_type, - metadata=self.metadata, - runtime_config=self.runtime_config, + metadata=runtime.metadata, + runtime_config=runtime.runtime_config, ) else: storage_key = StorageKey.create_from_storage_function( module, storage_item.value["name"], 
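Why the runtime version for a block is read from its *parent*: a runtime upgrade included in block N only takes effect after N, so block N itself is still decoded with the previous runtime. Schematically, the lookup chain the cached methods above implement:

```python
async def runtime_version_for(substrate, block_hash: str) -> int:
    parent_hash = await substrate.get_parent_block_hash(block_hash)    # LRU-cached, 512 entries
    runtime_info = await substrate.get_block_runtime_info(parent_hash)  # LRU-cached, 16 entries
    return runtime_info["specVersion"]  # key as returned by state_getRuntimeVersion
```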
params, - runtime_config=self.runtime_config, - metadata=self.runtime.metadata, + runtime_config=runtime.runtime_config, + metadata=runtime.metadata, ) method = "state_getStorageAt" return Preprocessed( @@ -2096,6 +2189,8 @@ async def _process_response( value_scale_type: Optional[str] = None, storage_item: Optional[ScaleType] = None, result_handler: Optional[ResultHandler] = None, + runtime: Optional[Runtime] = None, + force_legacy_decode: bool = False, ) -> tuple[Any, bool]: """ Processes the RPC call response by decoding it, returning it as is, or setting a handler for subscriptions, @@ -2107,6 +2202,8 @@ async def _process_response( value_scale_type: Scale Type string used for decoding ScaleBytes results storage_item: The ScaleType object used for decoding ScaleBytes results result_handler: the result handler coroutine used for handling longer-running subscriptions + runtime: Optional Runtime to use for decoding. If not specified, the currently-loaded `self.runtime` is used + force_legacy_decode: Whether to force the use of the legacy Metadata V14 decoder Returns: (decoded response, completion) @@ -2128,7 +2225,9 @@ async def _process_response( q = bytes(query_value) else: q = query_value - result = await self.decode_scale(value_scale_type, q) + result = await self.decode_scale( + value_scale_type, q, runtime=runtime, force_legacy=force_legacy_decode + ) if asyncio.iscoroutinefunction(result_handler): # For multipart responses as a result of subscriptions. message, bool_result = await result_handler(result, subscription_id) @@ -2142,6 +2241,8 @@ async def _make_rpc_request( storage_item: Optional[ScaleType] = None, result_handler: Optional[ResultHandler] = None, attempt: int = 1, + runtime: Optional[Runtime] = None, + force_legacy_decode: bool = False, ) -> RequestManager.RequestResults: request_manager = RequestManager(payloads) @@ -2185,6 +2286,8 @@ async def _make_rpc_request( value_scale_type, storage_item, result_handler, + runtime=runtime, + force_legacy_decode=force_legacy_decode, ) request_manager.add_response( @@ -2216,6 +2319,7 @@ async def _make_rpc_request( storage_item, result_handler, attempt + 1, + force_legacy_decode, ) return request_manager.get_results() @@ -2244,6 +2348,7 @@ async def rpc_request( result_handler: Optional[ResultHandler] = None, block_hash: Optional[str] = None, reuse_block_hash: bool = False, + runtime: Optional[Runtime] = None, ) -> Any: """ Makes an RPC request to the subtensor. Use this only if `self.query` and `self.query_multiple` and @@ -2257,6 +2362,8 @@ async def rpc_request( hash in the params, and not reusing the block hash reuse_block_hash: whether to reuse the block hash in the params — only mark as True if not supplying the block hash in the params, or via the `block_hash` parameter + runtime: Optional runtime to be used for decoding results of the request. If not specified, the + currently-loaded `self.runtime` is used. Returns: the response from the RPC request @@ -2271,7 +2378,9 @@ async def rpc_request( params + [block_hash] if block_hash else params, ) ] - result = await self._make_rpc_request(payloads, result_handler=result_handler) + result = await self._make_rpc_request( + payloads, result_handler=result_handler, runtime=runtime + ) if "error" in result[payload_id][0]: if "Failed to get runtime version" in ( err_msg := result[payload_id][0]["error"]["message"] @@ -2279,9 +2388,14 @@ async def rpc_request( logger.warning( "Failed to get runtime. Re-fetching from chain, and retrying." 
) - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) return await self.rpc_request( - method, params, result_handler, block_hash, reuse_block_hash + method, + params, + result_handler, + block_hash, + reuse_block_hash, + runtime=runtime, ) elif ( "Client error: Api called for an unknown Block: State already discarded" @@ -2296,8 +2410,17 @@ async def rpc_request( else: raise SubstrateRequestException(result[payload_id][0]) + @cached_fetcher(max_size=512) async def get_block_hash(self, block_id: int) -> str: - return await self._block_hash_fetcher.execute(block_id) + """ + Retrieves the hash of the specified block number + Args: + block_id: block number + + Returns: + Hash of the block + """ + return await self._get_block_hash(block_id) async def _get_block_hash(self, block_id: int) -> str: return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"] @@ -2338,10 +2461,10 @@ async def compose_call( if call_params is None: call_params = {} - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) - call = self.runtime_config.create_scale_object( - type_string="Call", metadata=self.runtime.metadata + call = runtime.runtime_config.create_scale_object( + type_string="Call", metadata=runtime.metadata ) call.encode( @@ -2361,6 +2484,7 @@ async def query_multiple( module: str, block_hash: Optional[str] = None, reuse_block_hash: bool = False, + runtime: Optional[Runtime] = None, ) -> dict[str, ScaleType]: """ Queries the subtensor. Only use this when making multiple queries, else use ``self.query`` @@ -2371,10 +2495,13 @@ async def query_multiple( block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash) if block_hash: self.last_block_hash = block_hash - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) preprocessed: tuple[Preprocessed] = await asyncio.gather( *[ - self._preprocess([x], block_hash, storage_function, module) + self._preprocess( + [x], block_hash, storage_function, module, runtime=runtime + ) for x in params ] ) @@ -2387,14 +2514,17 @@ async def query_multiple( storage_item = preprocessed[0].storage_item responses = await self._make_rpc_request( - all_info, value_scale_type, storage_item + all_info, value_scale_type, storage_item, runtime=runtime ) return { param: responses[p.queryable][0] for (param, p) in zip(params, preprocessed) } async def query_multi( - self, storage_keys: list[StorageKey], block_hash: Optional[str] = None + self, + storage_keys: list[StorageKey], + block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, ) -> list: """ Query multiple storage keys in one request. @@ -2417,15 +2547,20 @@ async def query_multi( Args: storage_keys: list of StorageKey objects block_hash: hash of the block to query against + runtime: Optional `Runtime` to be used for decoding. If not specified, the currently-loaded `self.runtime` + is used. 
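A plausible end-to-end use of `query_multi` with the new explicit `runtime` argument (endpoint, pallet, and account are illustrative):

```python
import asyncio

from async_substrate_interface.async_substrate import AsyncSubstrateInterface

ALICE = "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"  # well-known dev key


async def main():
    async with AsyncSubstrateInterface("wss://entrypoint-finney.opentensor.ai:443") as substrate:
        block_hash = await substrate.get_chain_head()
        runtime = await substrate.init_runtime(block_hash=block_hash)
        keys = [
            await substrate.create_storage_key(
                "System", "Account", [ALICE], block_hash=block_hash
            )
        ]
        for storage_key, value in await substrate.query_multi(
            keys, block_hash=block_hash, runtime=runtime
        ):
            print(storage_key, value)


asyncio.run(main())
```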
Returns: list of `(storage_key, scale_obj)` tuples """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) # Retrieve corresponding value response = await self.rpc_request( - "state_queryStorageAt", [[s.to_hex() for s in storage_keys], block_hash] + "state_queryStorageAt", + [[s.to_hex() for s in storage_keys], block_hash], + runtime=runtime, ) if "error" in response: @@ -2447,7 +2582,7 @@ async def query_multi( ( storage_key, await self.decode_scale( - storage_key.value_scale_type, change_data + storage_key.value_scale_type, change_data, runtime=runtime ), ), ) @@ -2459,6 +2594,7 @@ async def create_scale_object( type_string: str, data: Optional[ScaleBytes] = None, block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, **kwargs, ) -> "ScaleType": """ @@ -2469,16 +2605,19 @@ async def create_scale_object( type_string: Name of SCALE type to create data: ScaleBytes: ScaleBytes to decode block_hash: block hash for moment of decoding, when omitted the chain tip will be used + runtime: Optional `Runtime` to use for the creation of the scale object. If not specified, the + currently-loaded `self.runtime` will be used. kwargs: keyword args for the Scale Type constructor Returns: The created Scale Type object """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) if "metadata" not in kwargs: - kwargs["metadata"] = self.runtime.metadata + kwargs["metadata"] = runtime.metadata - return self.runtime.runtime_config.create_scale_object( + return runtime.runtime_config.create_scale_object( type_string, data=data, **kwargs ) @@ -2493,6 +2632,7 @@ async def generate_signature_payload( ) -> ScaleBytes: # Retrieve genesis hash genesis_hash = await self.get_block_hash(0) + runtime = await self.init_runtime(block_hash=None) if not era: era = "00" @@ -2502,7 +2642,7 @@ async def generate_signature_payload( block_hash = genesis_hash else: # Determine mortality of extrinsic - era_obj = self.runtime_config.create_scale_object("Era") + era_obj = runtime.runtime_config.create_scale_object("Era") if isinstance(era, dict) and "current" not in era and "phase" not in era: raise ValueError( @@ -2515,17 +2655,17 @@ async def generate_signature_payload( ) # Create signature payload - signature_payload = self.runtime_config.create_scale_object( + signature_payload = runtime.runtime_config.create_scale_object( "ExtrinsicPayloadValue" ) # Process signed extensions in metadata - if "signed_extensions" in self.runtime.metadata[1][1]["extrinsic"]: + if "signed_extensions" in runtime.metadata[1][1]["extrinsic"]: # Base signature payload signature_payload.type_mapping = [["call", "CallBytes"]] # Add signed extensions to payload - signed_extensions = self.runtime.metadata.get_signed_extensions() + signed_extensions = runtime.metadata.get_signed_extensions() if "CheckMortality" in signed_extensions: signature_payload.type_mapping.append( @@ -2614,10 +2754,10 @@ async def generate_signature_payload( "era": era, "nonce": nonce, "tip": tip, - "spec_version": self.runtime.runtime_version, + "spec_version": runtime.runtime_version, "genesis_hash": genesis_hash, "block_hash": block_hash, - "transaction_version": self.runtime.transaction_version, + "transaction_version": runtime.transaction_version, "asset_id": {"tip": tip, "asset_id": tip_asset_id}, "metadata_hash": None, "mode": "Disabled", @@ -2659,16 +2799,16 @@ async def create_signed_extrinsic( The signed Extrinsic """ # 
only support creating extrinsics for current block - await self.init_runtime(block_id=await self.get_block_number()) + runtime = await self.init_runtime() # Check requirements if not isinstance(call, GenericCall): raise TypeError("'call' must be of type Call") # Check if extrinsic version is supported - if self.runtime.metadata[1][1]["extrinsic"]["version"] != 4: # type: ignore + if runtime.metadata[1][1]["extrinsic"]["version"] != 4: # type: ignore raise NotImplementedError( - f"Extrinsic version {self.runtime.metadata[1][1]['extrinsic']['version']} not supported" # type: ignore + f"Extrinsic version {runtime.metadata[1][1]['extrinsic']['version']} not supported" # type: ignore ) # Retrieve nonce @@ -2712,7 +2852,7 @@ async def create_signed_extrinsic( # Create extrinsic extrinsic = self.runtime_config.create_scale_object( - type_string="Extrinsic", metadata=self.runtime.metadata + type_string="Extrinsic", metadata=runtime.metadata ) value = { @@ -2729,8 +2869,8 @@ async def create_signed_extrinsic( } # Check if ExtrinsicSignature is MultiSignature, otherwise omit signature_version - signature_cls = self.runtime_config.get_decoder_class("ExtrinsicSignature") - if issubclass(signature_cls, self.runtime_config.get_decoder_class("Enum")): + signature_cls = runtime.runtime_config.get_decoder_class("ExtrinsicSignature") + if issubclass(signature_cls, runtime.runtime_config.get_decoder_class("Enum")): value["signature_version"] = signature_version extrinsic.encode(value) @@ -2787,6 +2927,7 @@ async def _do_runtime_call_old( method: str, params: Optional[Union[list, dict]] = None, block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, ) -> ScaleType: logger.debug( f"Decoding old runtime call: {api}.{method} with params: {params} at block hash: {block_hash}" @@ -2817,10 +2958,14 @@ async def _do_runtime_call_old( # RPC request result_data = await self.rpc_request( - "state_call", [f"{api}_{method}", param_data.hex(), block_hash] + "state_call", + [f"{api}_{method}", param_data.hex(), block_hash], + runtime=runtime, ) result_vec_u8_bytes = hex_to_bytes(result_data["result"]) - result_bytes = await self.decode_scale("Vec", result_vec_u8_bytes) + result_bytes = await self.decode_scale( + "Vec", result_vec_u8_bytes, runtime=runtime + ) # Decode result # Get correct type @@ -2856,20 +3001,32 @@ async def runtime_call( params = {} try: - metadata_v15_value = runtime.metadata_v15.value() + if runtime.metadata_v15 is None: + _ = self.runtime_config.type_registry["runtime_api"][api]["methods"][ + method + ] + runtime_api_types = self.runtime_config.type_registry["runtime_api"][ + api + ].get("types", {}) + runtime.runtime_config.update_type_registry_types(runtime_api_types) + return await self._do_runtime_call_old( + api, method, params, block_hash, runtime=runtime + ) - apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]} - api_entry = apis[api] - methods = {entry["name"]: entry for entry in api_entry["methods"]} - runtime_call_def = methods[method] + else: + metadata_v15_value = runtime.metadata_v15.value() + + apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]} + api_entry = apis[api] + methods = {entry["name"]: entry for entry in api_entry["methods"]} + runtime_call_def = methods[method] + if _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value): + return await self._do_runtime_call_old( + api, method, params, block_hash, runtime=runtime + ) except KeyError: raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") - 
if _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value): - result = await self._do_runtime_call_old(api, method, params, block_hash) - - return result - if isinstance(params, list) and len(params) != len(runtime_call_def["inputs"]): raise ValueError( f"Number of parameter provided ({len(params)}) does not " @@ -2892,13 +3049,17 @@ async def runtime_call( # RPC request result_data = await self.rpc_request( - "state_call", [f"{api}_{method}", param_data.hex(), block_hash] + "state_call", + [f"{api}_{method}", param_data.hex(), block_hash], + runtime=runtime, ) output_type_string = f"scale_info::{runtime_call_def['output']}" # Decode result result_bytes = hex_to_bytes(result_data["result"]) - result_obj = ScaleObj(await self.decode_scale(output_type_string, result_bytes)) + result_obj = ScaleObj( + await self.decode_scale(output_type_string, result_bytes, runtime=runtime) + ) return result_obj @@ -2964,7 +3125,7 @@ async def get_metadata_constants(self, block_hash=None) -> list[dict]: constant_list = [] - for module_idx, module in enumerate(self.metadata.pallets): + for module_idx, module in enumerate(runtime.metadata.pallets): for constant in module.constants or []: constant_list.append( self.serialize_constant(constant, module, runtime.runtime_version) @@ -2972,7 +3133,13 @@ async def get_metadata_constants(self, block_hash=None) -> list[dict]: return constant_list - async def get_metadata_constant(self, module_name, constant_name, block_hash=None): + async def get_metadata_constant( + self, + module_name: str, + constant_name: str, + block_hash: Optional[str] = None, + runtime: Optional[Runtime] = None, + ): """ Retrieves the details of a constant for given module name, call function name and block_hash (or chaintip if block_hash is omitted) @@ -2981,13 +3148,15 @@ async def get_metadata_constant(self, module_name, constant_name, block_hash=Non module_name: name of the module you are querying constant_name: name of the constant you are querying block_hash: hash of the block at which to make the runtime API call + runtime: Runtime whose metadata you are querying. 
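Usage sketch for the constant helpers, mirroring what `initialize()` now does to auto-detect the chain's SS58 prefix (the fallback value is an assumption, not from the patch):

```python
async def detect_ss58_prefix(substrate) -> int:
    # get_constant returns a decoded ScaleObj, or None if the constant is absent
    constant = await substrate.get_constant("System", "SS58Prefix")
    return constant.value if constant is not None else 42  # assumed fallback
```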
Returns: MetadataModuleConstants """ - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) - for module in self.runtime.metadata.pallets: + for module in runtime.metadata.pallets: if module_name == module.name and module.constants: for constant in module.constants: if constant_name == constant.value["name"]: @@ -2999,6 +3168,7 @@ async def get_constant( constant_name: str, block_hash: Optional[str] = None, reuse_block_hash: bool = False, + runtime: Optional[Runtime] = None, ) -> Optional[ScaleObj]: """ Returns the decoded `ScaleType` object of the constant for given module name, call function name and block_hash @@ -3009,18 +3179,22 @@ async def get_constant( constant_name: Name of the constant to query block_hash: Hash of the block at which to make the runtime API call reuse_block_hash: Reuse last-used block hash if set to true + runtime: Runtime to use for querying the constant Returns: ScaleType from the runtime call """ block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash) constant = await self.get_metadata_constant( - module_name, constant_name, block_hash=block_hash + module_name, constant_name, block_hash=block_hash, runtime=runtime ) if constant: # Decode to ScaleType return await self.decode_scale( - constant.type, bytes(constant.constant_value), return_scale_obj=True + constant.type, + bytes(constant.constant_value), + return_scale_obj=True, + runtime=runtime, ) else: return None @@ -3080,14 +3254,14 @@ async def get_type_registry( Returns: dict mapping the type strings to the type decompositions """ - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) - if not self.implements_scaleinfo: + if not runtime.implements_scaleinfo: raise NotImplementedError("MetadataV14 or higher runtimes is required") type_registry = {} - for scale_info_type in self.metadata.portable_registry["types"]: + for scale_info_type in runtime.metadata.portable_registry["types"]: if ( "path" in scale_info_type.value["type"] and len(scale_info_type.value["type"]["path"]) > 0 @@ -3129,21 +3303,21 @@ async def get_metadata_modules(self, block_hash=None) -> list[dict[str, Any]]: Returns: List of metadata modules """ - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) return [ { "metadata_index": idx, "module_id": module.get_identifier(), "name": module.name, - "spec_version": self.runtime.runtime_version, + "spec_version": runtime.runtime_version, "count_call_functions": len(module.calls or []), "count_storage_functions": len(module.storage or []), "count_events": len(module.events or []), "count_constants": len(module.constants or []), "count_errors": len(module.errors or []), } - for idx, module in enumerate(self.metadata.pallets) + for idx, module in enumerate(runtime.metadata.pallets) ] async def get_metadata_module(self, name, block_hash=None) -> ScaleType: @@ -3157,9 +3331,9 @@ async def get_metadata_module(self, name, block_hash=None) -> ScaleType: Returns: MetadataModule """ - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) - return self.metadata.get_metadata_pallet(name) + return runtime.metadata.get_metadata_pallet(name) async def query( self, @@ -3170,6 +3344,8 @@ async def query( raw_storage_key: Optional[bytes] = None, subscription_handler=None, reuse_block_hash: bool = False, + runtime: Optional[Runtime] = None, + force_legacy_decode: bool 
= False, ) -> Optional[Union["ScaleObj", Any]]: """ Queries substrate. This should only be used when making a single request. For multiple requests, @@ -3178,9 +3354,15 @@ async def query( block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash) if block_hash: self.last_block_hash = block_hash - await self.init_runtime(block_hash=block_hash) + if not runtime: + runtime = await self.init_runtime(block_hash=block_hash) preprocessed: Preprocessed = await self._preprocess( - params, block_hash, storage_function, module, raw_storage_key + params, + block_hash, + storage_function, + module, + raw_storage_key, + runtime=runtime, ) payload = [ self.make_payload( @@ -3195,6 +3377,8 @@ async def query( value_scale_type, storage_item, result_handler=subscription_handler, + runtime=runtime, + force_legacy_decode=force_legacy_decode, ) result = responses[preprocessed.queryable][0] if isinstance(result, (list, tuple, int, float)): @@ -3212,6 +3396,7 @@ async def query_map( page_size: int = 100, ignore_decoding_errors: bool = False, reuse_block_hash: bool = False, + fully_exhaust: bool = False, ) -> AsyncQueryMapResult: """ Iterates over all key-pairs located at the given module and storage_function. The storage @@ -3242,6 +3427,8 @@ async def query_map( decoding reuse_block_hash: use True if you wish to make the query using the last-used block hash. Do not mark True if supplying a block_hash + fully_exhaust: Pull the entire result at once, rather than paginating. Only use if you need the entire query + map result. Returns: AsyncQueryMapResult object @@ -3252,7 +3439,7 @@ async def query_map( self.last_block_hash = block_hash runtime = await self.init_runtime(block_hash=block_hash) - metadata_pallet = self.runtime.metadata.get_metadata_pallet(module) + metadata_pallet = runtime.metadata.get_metadata_pallet(module) if not metadata_pallet: raise ValueError(f'Pallet "{module}" not found') storage_item = metadata_pallet.get_storage_function(storage_function) @@ -3279,8 +3466,8 @@ async def query_map( module, storage_item.value["name"], params, - runtime_config=self.runtime_config, - metadata=self.runtime.metadata, + runtime_config=runtime.runtime_config, + metadata=runtime.metadata, ) prefix = storage_key.to_hex() @@ -3292,10 +3479,16 @@ async def query_map( page_size = max_results # Retrieve storage keys - response = await self.rpc_request( - method="state_getKeysPaged", - params=[prefix, page_size, start_key, block_hash], - ) + if not fully_exhaust: + response = await self.rpc_request( + method="state_getKeysPaged", + params=[prefix, page_size, start_key, block_hash], + runtime=runtime, + ) + else: + response = await self.rpc_request( + method="state_getKeys", params=[prefix, block_hash], runtime=runtime + ) if "error" in response: raise SubstrateRequestException(response["error"]["message"]) @@ -3308,16 +3501,60 @@ async def query_map( if len(result_keys) > 0: last_key = result_keys[-1] - # Retrieve corresponding value - response = await self.rpc_request( - method="state_queryStorageAt", params=[result_keys, block_hash] - ) + # Retrieve corresponding value(s) + if not fully_exhaust: + response = await self.rpc_request( + method="state_queryStorageAt", + params=[result_keys, block_hash], + runtime=runtime, + ) + if "error" in response: + raise SubstrateRequestException(response["error"]["message"]) + for result_group in response["result"]: + result = decode_query_map( + result_group["changes"], + prefix, + runtime, + param_types, + params, + value_type, + key_hashers, + 
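When to reach for the new `fully_exhaust` flag: the default path pages through `state_getKeysPaged` lazily, while `fully_exhaust=True` fetches every key up front via `state_getKeys` and then pulls values in page-sized batches, five in flight at a time. An illustrative comparison (pallet/storage names and the iteration pattern are assumptions based on the query-map API):

```python
async def total_stake_paginated(substrate) -> int:
    # Default: keys are fetched page by page as iteration proceeds.
    result = await substrate.query_map(
        "SubtensorModule", "TotalHotkeyStake", page_size=100
    )
    total = 0
    async for _key, value in result:  # values assumed to be ScaleObj wrappers
        total += value.value
    return total


async def total_stake_exhaustive(substrate) -> int:
    # fully_exhaust=True: heavier on the node; use only when you genuinely
    # need the whole map in one shot.
    result = await substrate.query_map(
        "SubtensorModule", "TotalHotkeyStake", fully_exhaust=True
    )
    total = 0
    async for _key, value in result:
        total += value.value
    return total
```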
ignore_decoding_errors,
+                        self.decode_ss58,
+                    )
+            else:
+                all_responses = []
+                page_batches = [
+                    result_keys[i : i + page_size]
+                    for i in range(0, len(result_keys), page_size)
+                ]
+                changes = []
+                for batch_group in [
+                    # run five concurrent batch pulls; could go higher, but it's good to be good citizens
+                    # of the ecosystem
+                    page_batches[i : i + 5]
+                    for i in range(0, len(page_batches), 5)
+                ]:
+                    all_responses.extend(
+                        await asyncio.gather(
+                            *[
+                                self.rpc_request(
+                                    method="state_queryStorageAt",
+                                    params=[batch_keys, block_hash],
+                                    runtime=runtime,
+                                )
+                                for batch_keys in batch_group
+                            ]
+                        )
+                    )
+                for response in all_responses:
+                    if "error" in response:
+                        raise SubstrateRequestException(response["error"]["message"])
+                    for result_group in response["result"]:
+                        changes.extend(result_group["changes"])

-            if "error" in response:
-                raise SubstrateRequestException(response["error"]["message"])
-            for result_group in response["result"]:
                 result = decode_query_map(
-                    result_group["changes"],
+                    changes,
                     prefix,
                     runtime,
                     param_types,
@@ -3325,6 +3562,7 @@
                     value_type,
                     key_hashers,
                     ignore_decoding_errors,
+                    self.decode_ss58,
                 )
         return AsyncQueryMapResult(
             records=result,
@@ -3566,9 +3804,9 @@ async def get_metadata_call_function(
         Returns:
             list of call functions
         """
-        await self.init_runtime(block_hash=block_hash)
+        runtime = await self.init_runtime(block_hash=block_hash)

-        for pallet in self.runtime.metadata.pallets:
+        for pallet in runtime.metadata.pallets:
             if pallet.name == module_name and pallet.calls:
                 for call in pallet.calls:
                     if call.name == call_function_name:
@@ -3590,7 +3828,7 @@ async def get_metadata_events(self, block_hash=None) -> list[dict]:

         event_list = []

-        for event_index, (module, event) in self.metadata.event_index.items():
+        for event_index, (module, event) in runtime.metadata.event_index.items():
             event_list.append(
                 self.serialize_module_event(
                     module, event, runtime.runtime_version, event_index
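The new `fully_exhaust` path above trades pagination for concurrency: every storage key is fetched in one `state_getKeys` call, and the `state_queryStorageAt` value lookups are then gathered five batches at a time. A minimal usage sketch (the endpoint is the lite node from the test settings; any Substrate node works):

```python
import asyncio

from async_substrate_interface.async_substrate import AsyncSubstrateInterface


async def main():
    async with AsyncSubstrateInterface("wss://lite.sub.latent.to:443") as substrate:
        block_hash = await substrate.get_chain_finalised_head()
        # Pulls every key/value pair in one go instead of walking 100-key pages.
        result = await substrate.query_map(
            "SubtensorModule",
            "NetworksAdded",
            block_hash=block_hash,
            fully_exhaust=True,
        )
        async for key, value in result:
            print(key, value.value)


asyncio.run(main())
```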
diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py
index b5148a8..b7c4c15 100644
--- a/async_substrate_interface/sync_substrate.py
+++ b/async_substrate_interface/sync_substrate.py
@@ -13,17 +13,15 @@
     ss58_encode,
     MultiAccountId,
 )
-from scalecodec.base import RuntimeConfigurationObject, ScaleBytes, ScaleType
+from scalecodec.base import ScaleBytes, ScaleType
 from websockets.sync.client import connect, ClientConnection
 from websockets.exceptions import ConnectionClosed

-from async_substrate_interface.const import SS58_FORMAT
 from async_substrate_interface.errors import (
     ExtrinsicNotFound,
     SubstrateRequestException,
     BlockNotFound,
     MaxRetriesExceeded,
-    MetadataAtVersionNotFound,
     StateDiscardedError,
 )
 from async_substrate_interface.protocols import Keypair
@@ -45,6 +43,8 @@
     _determine_if_old_runtime_call,
     _bt_decode_to_dict_or_list,
     decode_query_map,
+    legacy_scale_decode,
+    convert_account_ids,
 )
 from async_substrate_interface.utils.storage import StorageKey
 from async_substrate_interface.type_registry import _TYPE_REGISTRY
@@ -256,16 +256,17 @@ def process_events(self):
                         self.__weight = dispatch_info["weight"]

                         if "Module" in dispatch_error:
-                            module_index = dispatch_error["Module"][0]["index"]
-                            error_index = int.from_bytes(
-                                bytes(dispatch_error["Module"][0]["error"]),
-                                byteorder="little",
-                                signed=False,
-                            )
+                            if isinstance(dispatch_error["Module"], tuple):
+                                module_index = dispatch_error["Module"][0]
+                                error_index = dispatch_error["Module"][1]
+                            else:
+                                module_index = dispatch_error["Module"]["index"]
+                                error_index = dispatch_error["Module"]["error"]
                             if isinstance(error_index, str):
                                 # Actual error index is first u8 in new [u8; 4] format
                                 error_index = int(error_index[2:4], 16)
+
                             module_error = self.substrate.metadata.get_module_error(
                                 module_index=module_index, error_index=error_index
                             )
@@ -487,6 +488,7 @@ def __init__(
         retry_timeout: float = 60.0,
         _mock: bool = False,
         _log_raw_websockets: bool = False,
+        decode_ss58: bool = False,
     ):
         """
         The sync compatible version of the subtensor interface commands we use in bittensor. Use this instance only
@@ -504,8 +506,16 @@
             retry_timeout: how long to wait since the last ping to retry the RPC request
             _mock: whether to use mock version of the subtensor interface
             _log_raw_websockets: whether to log raw websocket requests during RPC requests
+            decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples.
         """
+        super().__init__(
+            type_registry,
+            type_registry_preset,
+            use_remote_preset,
+            ss58_format,
+            decode_ss58,
+        )
         self.max_retries = max_retries
         self.retry_timeout = retry_timeout
         self.chain_endpoint = url
@@ -518,17 +528,10 @@
             "strict_scale_decode": True,
         }
         self.initialized = False
-        self.ss58_format = ss58_format
         self.type_registry = type_registry
         self.type_registry_preset = type_registry_preset
         self.runtime_cache = RuntimeCache()
-        self.runtime_config = RuntimeConfigurationObject(
-            ss58_format=self.ss58_format, implements_scale_info=True
-        )
         self.metadata_version_hex = "0x0f000000"  # v15
-        self.reload_type_registry()
-        self.registry_type_map = {}
-        self.type_id_to_name = {}
         self._mock = _mock
         self.log_raw_websockets = _log_raw_websockets
         if not _mock:
@@ -558,11 +561,30 @@ def initialize(self):
             chain = self.rpc_request("system_chain", [])
             self._chain = chain.get("result")
         self.init_runtime()
+        if self.ss58_format is None:
+            # Check and apply runtime constants
+            ss58_prefix_constant = self.get_constant(
+                "System", "SS58Prefix", block_hash=self.last_block_hash
+            )
+            if ss58_prefix_constant:
+                self.ss58_format = ss58_prefix_constant.value
+                self.runtime.ss58_format = ss58_prefix_constant.value
+                self.runtime.runtime_config.ss58_format = ss58_prefix_constant.value
         self.initialized = True

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.ws.close()

+    @property
+    def metadata(self):
+        if not self.runtime or self.runtime.metadata is None:
+            raise AttributeError(
+                "Metadata not found. This generally indicates that the SubstrateInterface object "
+                "is not properly initialized."
+            )
+        else:
+            return self.runtime.metadata
+
     @property
     def properties(self):
         if self._properties is None:
@@ -645,14 +667,13 @@ def _load_registry_at_block(self, block_hash: Optional[str]) -> MetadataV15:
                 "Client error: Execution failed: Other: Exported method Metadata_metadata_at_version is not found"
                 in e.args
             ):
-                raise MetadataAtVersionNotFound
+                return None, None
             else:
                 raise e
         metadata_option_hex_str = metadata_rpc_result["result"]
         metadata_option_bytes = bytes.fromhex(metadata_option_hex_str[2:])
         metadata = MetadataV15.decode_from_metadata_option(metadata_option_bytes)
         registry = PortableRegistry.from_metadata_v15(metadata)
-        self._load_registry_type_map(registry)
         return metadata, registry

     def decode_scale(
@@ -660,6 +681,7 @@
         type_string: str,
         scale_bytes: bytes,
         return_scale_obj=False,
+        force_legacy: bool = False,
     ) -> Union[ScaleObj, Any]:
         """
         Helper function to decode arbitrary SCALE-bytes (e.g.
0x02000000) according to given RUST type_string @@ -670,15 +692,30 @@ def decode_scale( type_string: the type string of the SCALE object for decoding scale_bytes: the bytes representation of the SCALE object to decode return_scale_obj: Whether to return the decoded value wrapped in a SCALE-object-like wrapper, or raw. + force_legacy: Whether to force the use of the legacy Metadata V14 decoder Returns: Decoded object """ if type_string == "scale_info::0": # Is an AccountId # Decode AccountId bytes to SS58 address - return ss58_encode(scale_bytes, SS58_FORMAT) + return ss58_encode(scale_bytes, self.ss58_format) else: - obj = decode_by_type_string(type_string, self.runtime.registry, scale_bytes) + if self.runtime.metadata_v15 is not None and force_legacy is False: + obj = decode_by_type_string( + type_string, self.runtime.registry, scale_bytes + ) + if self.decode_ss58: + try: + type_str_int = int(type_string.split("::")[1]) + decoded_type_str = self.runtime.type_id_to_name[type_str_int] + obj = convert_account_ids( + obj, decoded_type_str, self.ss58_format + ) + except (ValueError, KeyError): + pass + else: + obj = legacy_scale_decode(type_string, scale_bytes, self.runtime) if return_scale_obj: return ScaleObj(obj) else: @@ -688,21 +725,21 @@ def load_runtime(self, runtime): self.runtime = runtime # Update type registry - self.reload_type_registry(use_remote_preset=False, auto_discover=True) + self.runtime.reload_type_registry(use_remote_preset=False, auto_discover=True) self.runtime_config.set_active_spec_version_id(runtime.runtime_version) - if self.implements_scaleinfo: + if self.runtime.implements_scaleinfo: logger.debug("Add PortableRegistry from metadata to type registry") self.runtime_config.add_portable_registry(runtime.metadata) # Set runtime compatibility flags try: _ = self.runtime_config.create_scale_object("sp_weights::weight_v2::Weight") - self.config["is_weight_v2"] = True + self.runtime.config["is_weight_v2"] = True self.runtime_config.update_type_registry_types( {"Weight": "sp_weights::weight_v2::Weight"} ) except NotImplementedError: - self.config["is_weight_v2"] = False + self.runtime.config["is_weight_v2"] = False self.runtime_config.update_type_registry_types({"Weight": "WeightV1"}) def init_runtime( @@ -727,11 +764,25 @@ def init_runtime( if block_id and block_hash: raise ValueError("Cannot provide block_hash and block_id at the same time") - if block_id: + if block_id is not None: + if runtime := self.runtime_cache.retrieve(block=block_id): + self.runtime = runtime + self.runtime.load_runtime() + if self.runtime.registry: + self.runtime.load_registry_type_map() + return self.runtime block_hash = self.get_block_hash(block_id) if not block_hash: block_hash = self.get_chain_head() + else: + self.last_block_hash = block_hash + if runtime := self.runtime_cache.retrieve(block_hash=block_hash): + self.runtime = runtime + self.runtime.load_runtime() + if self.runtime.registry: + self.runtime.load_registry_type_map() + return self.runtime runtime_version = self.get_block_runtime_version_for(block_hash) if runtime_version is None: @@ -742,58 +793,76 @@ def init_runtime( if self.runtime and runtime_version == self.runtime.runtime_version: return self.runtime - runtime = self.runtime_cache.retrieve(runtime_version=runtime_version) - if not runtime: - self.last_block_hash = block_hash + if runtime := self.runtime_cache.retrieve(runtime_version=runtime_version): + self.runtime = runtime + self.runtime.load_runtime() + if self.runtime.registry: + self.runtime.load_registry_type_map() 
+ return runtime + else: + self.runtime = self.get_runtime_for_version(runtime_version, block_hash) + self.runtime.load_runtime() + if self.runtime.registry: + self.runtime.load_registry_type_map() + return self.runtime - runtime_block_hash = self.get_parent_block_hash(block_hash) + def get_runtime_for_version( + self, runtime_version: int, block_hash: Optional[str] = None + ) -> Runtime: + """ + Retrieves the `Runtime` for a given runtime version at a given block hash. + Args: + runtime_version: version of the runtime (from `get_block_runtime_version_for`) + block_hash: hash of the block to query - runtime_info = self.get_block_runtime_info(runtime_block_hash) + Returns: + Runtime object for the given runtime version + """ + if not block_hash: + block_hash = self.get_chain_head() + runtime_block_hash = self.get_parent_block_hash(block_hash) + block_number = self.get_block_number(block_hash) + runtime_info = self.get_block_runtime_info(runtime_block_hash) - metadata = self.get_block_metadata( - block_hash=runtime_block_hash, decode=True - ) - if metadata is None: - # does this ever happen? - raise SubstrateRequestException( - f"No metadata for block '{runtime_block_hash}'" - ) - logger.debug( - "Retrieved metadata for {} from Substrate node".format(runtime_version) + metadata = self.get_block_metadata(block_hash=runtime_block_hash, decode=True) + if metadata is None: + # does this ever happen? + raise SubstrateRequestException( + f"No metadata for block '{runtime_block_hash}'" ) + logger.debug( + "Retrieved metadata for {} from Substrate node".format(runtime_version) + ) - metadata_v15, registry = self._load_registry_at_block( - block_hash=runtime_block_hash - ) + metadata_v15, registry = self._load_registry_at_block( + block_hash=runtime_block_hash + ) + if metadata_v15 is not None: logger.debug( - "Retrieved metadata v15 for {} from Substrate node".format( - runtime_version - ) + f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node" ) - - runtime = Runtime( - chain=self.chain, - runtime_config=self.runtime_config, - metadata=metadata, - type_registry=self.type_registry, - metadata_v15=metadata_v15, - runtime_info=runtime_info, - registry=registry, - ) - self.runtime_cache.add_item( - runtime_version=runtime_version, runtime=runtime - ) - - self.load_runtime(runtime) - - if self.ss58_format is None: - # Check and apply runtime constants - ss58_prefix_constant = self.get_constant( - "System", "SS58Prefix", block_hash=block_hash + else: + logger.debug( + f"Exported method Metadata_metadata_at_version is not found for {runtime_version}. This indicates the " + f"block is quite old, decoding for this block will use legacy Python decoding." ) - if ss58_prefix_constant: - self.ss58_format = ss58_prefix_constant + runtime = Runtime( + chain=self.chain, + runtime_config=self.runtime_config, + metadata=metadata, + type_registry=self.type_registry, + metadata_v15=metadata_v15, + runtime_info=runtime_info, + registry=registry, + ss58_format=self.ss58_format, + ) + self.runtime_cache.add_item( + block=block_number, + block_hash=block_hash, + runtime_version=runtime_version, + runtime=runtime, + ) return runtime def create_storage_key( @@ -1069,6 +1138,7 @@ def get_metadata_runtime_call_function( Args: api: Name of the runtime API e.g. 'TransactionPaymentApi' method: Name of the method e.g. 
'query_fee_details' + block_hash: block hash whose metadata to query Returns: runtime call function @@ -1095,41 +1165,6 @@ def get_metadata_runtime_call_function( return runtime_call_def_obj - def get_metadata_runtime_call_function( - self, api: str, method: str - ) -> GenericRuntimeCallDefinition: - """ - Get details of a runtime API call - - Args: - api: Name of the runtime API e.g. 'TransactionPaymentApi' - method: Name of the method e.g. 'query_fee_details' - - Returns: - GenericRuntimeCallDefinition - """ - self.init_runtime() - - try: - runtime_call_def = self.runtime_config.type_registry["runtime_api"][api][ - "methods" - ][method] - runtime_call_def["api"] = api - runtime_call_def["method"] = method - runtime_api_types = self.runtime_config.type_registry["runtime_api"][ - api - ].get("types", {}) - except KeyError: - raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") - - # Add runtime API types to registry - self.runtime_config.update_type_registry_types(runtime_api_types) - - runtime_call_def_obj = self.create_scale_object("RuntimeCallDefinition") - runtime_call_def_obj.encode(runtime_call_def) - - return runtime_call_def_obj - def _get_block_handler( self, block_hash: str, @@ -1194,7 +1229,7 @@ def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: block_data["header"]["digest"]["logs"][idx] = log_digest if include_author and "PreRuntime" in log_digest.value: - if self.implements_scaleinfo: + if self.runtime.implements_scaleinfo: engine = bytes(log_digest[1][0]) # Retrieve validator set parent_hash = block_data["header"]["parentHash"] @@ -1569,11 +1604,16 @@ def convert_event_data(data): attributes = attributes_data if isinstance(attributes, dict): for key, value in attributes.items(): + if key == "who": + who = ss58_encode(bytes(value[0]), self.ss58_format) + attributes["who"] = who if isinstance(value, dict): # Convert nested single-key dictionaries to their keys as strings - sub_key = next(iter(value.keys())) - if value[sub_key] == (): - attributes[key] = sub_key + for sub_key, sub_value in value.items(): + if isinstance(sub_value, dict): + for sub_sub_key, sub_sub_value in sub_value.items(): + if sub_sub_value == (): + attributes[key][sub_key] = sub_sub_key # Create the converted dictionary converted = { @@ -1595,11 +1635,15 @@ def convert_event_data(data): block_hash = self.get_chain_head() storage_obj = self.query( - module="System", storage_function="Events", block_hash=block_hash + module="System", + storage_function="Events", + block_hash=block_hash, + force_legacy_decode=True, ) + # bt-decode Metadata V15 is not ideal for events. 
Force legacy decoding for this if storage_obj: for item in list(storage_obj): - events.append(convert_event_data(item)) + events.append(item) return events def get_metadata(self, block_hash=None) -> MetadataV15: @@ -1781,6 +1825,7 @@ def _process_response( value_scale_type: Optional[str] = None, storage_item: Optional[ScaleType] = None, result_handler: Optional[ResultHandler] = None, + force_legacy_decode: bool = False, ) -> tuple[Any, bool]: """ Processes the RPC call response by decoding it, returning it as is, or setting a handler for subscriptions, @@ -1792,6 +1837,7 @@ def _process_response( value_scale_type: Scale Type string used for decoding ScaleBytes results storage_item: The ScaleType object used for decoding ScaleBytes results result_handler: the result handler coroutine used for handling longer-running subscriptions + force_legacy_decode: Whether to force legacy Metadata V14 decoding of the response Returns: (decoded response, completion) @@ -1813,7 +1859,9 @@ def _process_response( q = bytes(query_value) else: q = query_value - result = self.decode_scale(value_scale_type, q) + result = self.decode_scale( + value_scale_type, q, force_legacy=force_legacy_decode + ) if isinstance(result_handler, Callable): # For multipart responses as a result of subscriptions. message, bool_result = result_handler(result, subscription_id) @@ -1827,6 +1875,7 @@ def _make_rpc_request( storage_item: Optional[ScaleType] = None, result_handler: Optional[ResultHandler] = None, attempt: int = 1, + force_legacy_decode: bool = False, ) -> RequestManager.RequestResults: request_manager = RequestManager(payloads) _received = {} @@ -1860,6 +1909,7 @@ def _make_rpc_request( storage_item, result_handler, attempt + 1, + force_legacy_decode, ) if "id" in response: _received[response["id"]] = response @@ -1891,6 +1941,7 @@ def _make_rpc_request( value_scale_type, storage_item, result_handler, + force_legacy_decode, ) request_manager.add_response( item_id, decoded_response, complete @@ -2520,20 +2571,28 @@ def runtime_call( params = {} try: - metadata_v15_value = runtime.metadata_v15.value() + if runtime.metadata_v15 is None: + _ = self.runtime_config.type_registry["runtime_api"][api]["methods"][ + method + ] + runtime_api_types = self.runtime_config.type_registry["runtime_api"][ + api + ].get("types", {}) + runtime.runtime_config.update_type_registry_types(runtime_api_types) + return self._do_runtime_call_old(api, method, params, block_hash) + else: + metadata_v15_value = runtime.metadata_v15.value() + + apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]} + api_entry = apis[api] + methods = {entry["name"]: entry for entry in api_entry["methods"]} + runtime_call_def = methods[method] + if _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value): + return self._do_runtime_call_old(api, method, params, block_hash) - apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]} - api_entry = apis[api] - methods = {entry["name"]: entry for entry in api_entry["methods"]} - runtime_call_def = methods[method] except KeyError: raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") - if _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value): - result = self._do_runtime_call_old(api, method, params, block_hash) - - return result - if isinstance(params, list) and len(params) != len(runtime_call_def["inputs"]): raise ValueError( f"Number of parameter provided ({len(params)}) does not " @@ -2545,13 +2604,15 @@ def runtime_call( for idx, param in 
enumerate(runtime_call_def["inputs"]):
             param_type_string = f"scale_info::{param['ty']}"
             if isinstance(params, list):
-                param_data += self.encode_scale(param_type_string, params[idx])
+                param_data += self.encode_scale(
+                    param_type_string, params[idx], runtime=runtime
+                )
             else:
                 if param["name"] not in params:
                     raise ValueError(f"Runtime Call param '{param['name']}' is missing")

                 param_data += self.encode_scale(
-                    param_type_string, params[param["name"]]
+                    param_type_string, params[param["name"]], runtime=runtime
                 )

         # RPC request
@@ -2733,7 +2794,7 @@ def get_type_registry(self, block_hash: str = None, max_recursion: int = 4) -> d
         """
         self.init_runtime(block_hash=block_hash)

-        if not self.implements_scaleinfo:
+        if not self.runtime.implements_scaleinfo:
             raise NotImplementedError("MetadataV14 or higher runtimes is required")

         type_registry = {}
@@ -2819,6 +2880,7 @@ def query(
         raw_storage_key: Optional[bytes] = None,
         subscription_handler=None,
         reuse_block_hash: bool = False,
+        force_legacy_decode: bool = False,
     ) -> Optional[Union["ScaleObj", Any]]:
         """
         Queries substrate. This should only be used when making a single request. For multiple requests,
@@ -2844,6 +2906,7 @@
             value_scale_type,
             storage_item,
             result_handler=subscription_handler,
+            force_legacy_decode=force_legacy_decode,
         )
         result = responses[preprocessed.queryable][0]
         if isinstance(result, (list, tuple, int, float)):
@@ -2895,7 +2958,6 @@ def query_map(
         Returns:
             QueryMapResult object
         """
-        hex_to_bytes_ = hex_to_bytes
         params = params or []
         block_hash = self._get_current_block_hash(block_hash, reuse_block_hash)
         if block_hash:
@@ -2977,6 +3039,7 @@
             value_type,
             key_hashers,
             ignore_decoding_errors,
+            self.decode_ss58,
         )
         return QueryMapResult(
             records=result,
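Both interfaces now take the `decode_ss58` flag threaded through above: when it is True, AccountId32 values in `query` and `query_map` results come back as SS58 address strings instead of raw byte tuples. A short sketch mirroring the integration tests further down (endpoint taken from the test settings):

```python
from async_substrate_interface.sync_substrate import SubstrateInterface

with SubstrateInterface(
    "wss://lite.sub.latent.to:443", ss58_format=42, decode_ss58=True
) as substrate:
    qm = substrate.query_map("SubtensorModule", "OwnedHotkeys")
    for key, value in qm.records:
        # With decode_ss58=True, the map key is an SS58 address string
        # rather than a tuple wrapping 32 raw AccountId bytes.
        assert isinstance(key, str)
```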
+ """ + self.last_used = runtime if block is not None: self.blocks[block] = runtime if block_hash is not None: @@ -49,18 +67,35 @@ def retrieve( block_hash: Optional[str] = None, runtime_version: Optional[int] = None, ) -> Optional["Runtime"]: + """ + Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version. + Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`. + """ if block is not None: - return self.blocks.get(block) - elif block_hash is not None: - return self.block_hashes.get(block_hash) - elif runtime_version is not None: - return self.versions.get(runtime_version) - else: - return None + runtime = self.blocks.get(block) + if runtime is not None: + self.last_used = runtime + return runtime + if block_hash is not None: + runtime = self.block_hashes.get(block_hash) + if runtime is not None: + self.last_used = runtime + return runtime + if runtime_version is not None: + runtime = self.versions.get(runtime_version) + if runtime is not None: + self.last_used = runtime + return runtime + return None class Runtime: - runtime_version = None + """ + The Runtime object holds the necessary metadata and registry information required to do necessary scale encoding and + decoding. Currently only Metadata V15 is supported for decoding, though we plan to release legacy decoding options. + """ + + runtime_version: Optional[int] = None transaction_version = None cache_region = None metadata = None @@ -69,17 +104,21 @@ class Runtime: runtime_info = None type_registry_preset = None registry: Optional[PortableRegistry] = None + registry_type_map: dict[str, int] + type_id_to_name: dict[int, str] def __init__( self, - chain, + chain: str, runtime_config: RuntimeConfigurationObject, metadata, type_registry, metadata_v15=None, runtime_info=None, registry=None, + ss58_format=SS58_FORMAT, ): + self.ss58_format = ss58_format self.config = {} self.chain = chain self.type_registry = type_registry @@ -90,97 +129,230 @@ def __init__( self.registry = registry self.runtime_version = runtime_info.get("specVersion") self.transaction_version = runtime_info.get("transactionVersion") + self.load_runtime() + if registry is not None: + self.load_registry_type_map() + + def load_runtime(self): + """ + Initial loading of the runtime's type registry information. 
+ """ + # Update type registry + self.reload_type_registry(use_remote_preset=False, auto_discover=True) + + self.runtime_config.set_active_spec_version_id(self.runtime_version) + if self.implements_scaleinfo: + logger.debug("Adding PortableRegistry from metadata to type registry") + self.runtime_config.add_portable_registry(self.metadata) + # Set runtime compatibility flags + try: + _ = self.runtime_config.create_scale_object("sp_weights::weight_v2::Weight") + self.config["is_weight_v2"] = True + self.runtime_config.update_type_registry_types( + {"Weight": "sp_weights::weight_v2::Weight"} + ) + except NotImplementedError: + self.config["is_weight_v2"] = False + self.runtime_config.update_type_registry_types({"Weight": "WeightV1"}) + + @property + def implements_scaleinfo(self) -> Optional[bool]: + """ + Returns True if current runtime implements a `PortableRegistry` (`MetadataV14` and higher) + """ + if self.metadata: + return self.metadata.portable_registry is not None + else: + return None def __str__(self): return f"Runtime: {self.chain} | {self.config}" + def reload_type_registry( + self, use_remote_preset: bool = True, auto_discover: bool = True + ): + """ + Reload type registry and preset used to instantiate the SubstrateInterface object. Useful to periodically apply + changes in type definitions when a runtime upgrade occurred + + Args: + use_remote_preset: When True preset is downloaded from Github master, otherwise use files from local + installed scalecodec package + auto_discover: Whether to automatically discover the type registry presets based on the chain name and the + type registry + """ + self.runtime_config.clear_type_registry() + + self.runtime_config.implements_scale_info = self.implements_scaleinfo + + # Load metadata types in runtime configuration + self.runtime_config.update_type_registry(load_type_registry_preset(name="core")) + self.apply_type_registry_presets( + use_remote_preset=use_remote_preset, auto_discover=auto_discover + ) + + def apply_type_registry_presets( + self, + use_remote_preset: bool = True, + auto_discover: bool = True, + ): + """ + Applies type registry presets to the runtime + + Args: + use_remote_preset: whether to use presets from remote + auto_discover: whether to use presets from local installed scalecodec package + """ + if self.type_registry_preset is not None: + # Load type registry according to preset + type_registry_preset_dict = load_type_registry_preset( + name=self.type_registry_preset, use_remote_preset=use_remote_preset + ) + + if not type_registry_preset_dict: + raise ValueError( + f"Type registry preset '{self.type_registry_preset}' not found" + ) + + elif auto_discover: + # Try to auto discover type registry preset by chain name + type_registry_name = self.chain.lower().replace(" ", "-") + try: + type_registry_preset_dict = load_type_registry_preset( + type_registry_name + ) + self.type_registry_preset = type_registry_name + except ValueError: + type_registry_preset_dict = None + + else: + type_registry_preset_dict = None + + if type_registry_preset_dict: + # Load type registries in runtime configuration + if self.implements_scaleinfo is False: + # Only runtime with no embedded types in metadata need the default set of explicit defined types + self.runtime_config.update_type_registry( + load_type_registry_preset( + "legacy", use_remote_preset=use_remote_preset + ) + ) + + if self.type_registry_preset != "legacy": + self.runtime_config.update_type_registry(type_registry_preset_dict) + + if self.type_registry: + # Load type 
registries in runtime configuration + self.runtime_config.update_type_registry(self.type_registry) + + def load_registry_type_map(self) -> None: + """ + Loads the runtime's type mapping according to registry + """ + registry_type_map = {} + type_id_to_name = {} + types = json.loads(self.registry.registry)["types"] + type_by_id = {entry["id"]: entry for entry in types} + + # Pass 1: Gather simple types + for type_entry in types: + type_id = type_entry["id"] + type_def = type_entry["type"]["def"] + type_path = type_entry["type"].get("path") + if type_entry.get("params") or "variant" in type_def: + continue + if type_path: + type_name = type_path[-1] + registry_type_map[type_name] = type_id + type_id_to_name[type_id] = type_name + else: + # Possibly a primitive + if "primitive" in type_def: + prim_name = type_def["primitive"] + registry_type_map[prim_name] = type_id + type_id_to_name[type_id] = prim_name + + # Pass 2: Resolve remaining types + pending_ids = set(type_by_id.keys()) - set(type_id_to_name.keys()) + + def resolve_type_definition(type_id_): + type_entry_ = type_by_id[type_id_] + type_def_ = type_entry_["type"]["def"] + type_path_ = type_entry_["type"].get("path", []) + type_params = type_entry_["type"].get("params", []) + + if type_id_ in type_id_to_name: + return type_id_to_name[type_id_] + + # Resolve complex types with paths (including generics like Option etc) + if type_path_: + type_name_ = type_path_[-1] + if type_params: + inner_names = [] + for param in type_params: + dep_id = param["type"] + if dep_id not in type_id_to_name: + return None + inner_names.append(type_id_to_name[dep_id]) + return f"{type_name_}<{', '.join(inner_names)}>" + if "variant" in type_def_: + return None + return type_name_ + + elif "sequence" in type_def_: + sequence_type_id = type_def_["sequence"]["type"] + inner_type = type_id_to_name.get(sequence_type_id) + if inner_type: + type_name_ = f"Vec<{inner_type}>" + return type_name_ + + elif "array" in type_def_: + array_type_id = type_def_["array"]["type"] + inner_type = type_id_to_name.get(array_type_id) + maybe_len = type_def_["array"].get("len") + if inner_type: + if maybe_len: + type_name_ = f"[{inner_type}; {maybe_len}]" + else: + type_name_ = f"[{inner_type}]" + return type_name_ -# @property -# def implements_scaleinfo(self) -> bool: -# """ -# Returns True if current runtime implementation a `PortableRegistry` (`MetadataV14` and higher) -# """ -# if self.metadata: -# return self.metadata.portable_registry is not None -# else: -# return False -# -# def reload_type_registry( -# self, use_remote_preset: bool = True, auto_discover: bool = True -# ): -# """ -# Reload type registry and preset used to instantiate the SubstrateInterface object. 
Useful to periodically apply -# changes in type definitions when a runtime upgrade occurred -# -# Args: -# use_remote_preset: When True preset is downloaded from Github master, otherwise use files from local -# installed scalecodec package -# auto_discover: Whether to automatically discover the type registry presets based on the chain name and the -# type registry -# """ -# self.runtime_config.clear_type_registry() -# -# self.runtime_config.implements_scale_info = self.implements_scaleinfo -# -# # Load metadata types in runtime configuration -# self.runtime_config.update_type_registry(load_type_registry_preset(name="core")) -# self.apply_type_registry_presets( -# use_remote_preset=use_remote_preset, auto_discover=auto_discover -# ) -# -# def apply_type_registry_presets( -# self, -# use_remote_preset: bool = True, -# auto_discover: bool = True, -# ): -# """ -# Applies type registry presets to the runtime -# -# Args: -# use_remote_preset: whether to use presets from remote -# auto_discover: whether to use presets from local installed scalecodec package -# """ -# if self.type_registry_preset is not None: -# # Load type registry according to preset -# type_registry_preset_dict = load_type_registry_preset( -# name=self.type_registry_preset, use_remote_preset=use_remote_preset -# ) -# -# if not type_registry_preset_dict: -# raise ValueError( -# f"Type registry preset '{self.type_registry_preset}' not found" -# ) -# -# elif auto_discover: -# # Try to auto discover type registry preset by chain name -# type_registry_name = self.chain.lower().replace(" ", "-") -# try: -# type_registry_preset_dict = load_type_registry_preset( -# type_registry_name -# ) -# self.type_registry_preset = type_registry_name -# except ValueError: -# type_registry_preset_dict = None -# -# else: -# type_registry_preset_dict = None -# -# if type_registry_preset_dict: -# # Load type registries in runtime configuration -# if self.implements_scaleinfo is False: -# # Only runtime with no embedded types in metadata need the default set of explicit defined types -# self.runtime_config.update_type_registry( -# load_type_registry_preset( -# "legacy", use_remote_preset=use_remote_preset -# ) -# ) -# -# if self.type_registry_preset != "legacy": -# self.runtime_config.update_type_registry(type_registry_preset_dict) -# -# if self.type_registry: -# # Load type registries in runtime configuration -# self.runtime_config.update_type_registry(self.type_registry) + elif "compact" in type_def_: + compact_type_id = type_def_["compact"]["type"] + inner_type = type_id_to_name.get(compact_type_id) + if inner_type: + type_name_ = f"Compact<{inner_type}>" + return type_name_ + + elif "tuple" in type_def_: + tuple_type_ids = type_def_["tuple"] + type_names = [] + for inner_type_id in tuple_type_ids: + if inner_type_id not in type_id_to_name: + return None + type_names.append(type_id_to_name[inner_type_id]) + type_name_ = ", ".join(type_names) + type_name_ = f"({type_name_})" + return type_name_ + + elif "variant" in type_def_: + return None + + return None + + resolved_type = True + while resolved_type and pending_ids: + resolved_type = False + for type_id in list(pending_ids): + name = resolve_type_definition(type_id) + if name is not None: + type_id_to_name[type_id] = name + registry_type_map[name] = type_id + pending_ids.remove(type_id) + resolved_type = True + + self.registry_type_map = registry_type_map + self.type_id_to_name = type_id_to_name class RequestManager: @@ -373,40 +545,61 @@ class SubstrateMixin(ABC): type_registry: Optional[dict] 
ss58_format: Optional[int] ws_max_size = 2**32 - registry_type_map: dict[str, int] - type_id_to_name: dict[int, str] - runtime: Runtime = None + runtime: Runtime = None # TODO remove - @property - def chain(self): - """ - Returns the substrate chain currently associated with object - """ - return self._chain - - @property - def metadata(self): - if not self.runtime or self.runtime.metadata is None: - raise AttributeError( - "Metadata not found. This generally indicates that the AsyncSubstrateInterface object " - "is not properly async initialized." + def __init__( + self, + type_registry: Optional[dict] = None, + type_registry_preset: Optional[str] = None, + use_remote_preset: bool = False, + ss58_format: Optional[int] = None, + decode_ss58: bool = False, + ): + # We load a very basic RuntimeConfigurationObject that is only used for the initial metadata decoding + self.decode_ss58 = decode_ss58 + self.runtime_config = RuntimeConfigurationObject(ss58_format=ss58_format) + self.ss58_format = ss58_format + self.runtime_config.update_type_registry(load_type_registry_preset(name="core")) + if type_registry_preset is not None: + type_registry_preset_dict = load_type_registry_preset( + name=type_registry_preset, use_remote_preset=use_remote_preset ) + if not type_registry_preset_dict: + raise ValueError( + f"Type registry preset '{type_registry_preset}' not found" + ) else: - return self.runtime.metadata + type_registry_preset_dict = None + + if type_registry_preset_dict: + self.runtime_config.update_type_registry( + load_type_registry_preset("legacy", use_remote_preset=use_remote_preset) + ) + if type_registry_preset != "legacy": + self.runtime_config.update_type_registry(type_registry_preset_dict) + if type_registry: + # Load type registries in runtime configuration + self.runtime_config.update_type_registry(type_registry) + + def _runtime_config_copy(self, implements_scale_info: bool = False): + runtime_config = RuntimeConfigurationObject( + ss58_format=self.ss58_format, implements_scale_info=implements_scale_info + ) + runtime_config.active_spec_version_id = ( + self.runtime_config.active_spec_version_id + ) + runtime_config.chain_id = self.runtime_config.chain_id + # TODO. This works, but deepcopy does not. Indicating this gets updated somewhere else. 
+ runtime_config.type_registry = self.runtime_config.type_registry + assert runtime_config.type_registry == self.runtime_config.type_registry + return runtime_config @property - def implements_scaleinfo(self) -> Optional[bool]: + def chain(self): """ - Returns True if current runtime implementation a `PortableRegistry` (`MetadataV14` and higher) - - Returns - ------- - bool + Returns the substrate chain currently associated with object """ - if self.runtime and self.runtime.metadata: - return self.runtime.metadata.portable_registry is not None - else: - return None + return self._chain def ss58_encode( self, public_key: Union[str, bytes], ss58_format: int = None @@ -454,7 +647,11 @@ def is_valid_ss58_address(self, value: str) -> bool: return is_valid_ss58_address(value, valid_ss58_format=self.ss58_format) def serialize_storage_item( - self, storage_item: ScaleType, module, spec_version_id + self, + storage_item: ScaleType, + module: str, + spec_version_id: int, + runtime: Optional[Runtime] = None, ) -> dict: """ Helper function to serialize a storage item @@ -463,10 +660,17 @@ def serialize_storage_item( storage_item: the storage item to serialize module: the module to use to serialize the storage item spec_version_id: the version id + runtime: The runtime to serialize the storage item Returns: dict """ + if not runtime: + runtime = self.runtime + metadata = self.metadata + else: + metadata = runtime.metadata + storage_dict = { "storage_name": storage_item.name, "storage_modifier": storage_item.modifier, @@ -497,10 +701,10 @@ def serialize_storage_item( query_value = storage_item.value_object["default"].value_object try: - obj = self.runtime_config.create_scale_object( + obj = runtime.runtime_config.create_scale_object( type_string=value_scale_type, data=ScaleBytes(query_value), - metadata=self.metadata, + metadata=metadata, ) obj.decode() storage_dict["storage_default"] = obj.decode() @@ -622,183 +826,6 @@ def serialize_module_error(module, error, spec_version) -> dict: "spec_version": spec_version, } - def _load_registry_type_map(self, registry): - registry_type_map = {} - type_id_to_name = {} - types = json.loads(registry.registry)["types"] - type_by_id = {entry["id"]: entry for entry in types} - - # Pass 1: Gather simple types - for type_entry in types: - type_id = type_entry["id"] - type_def = type_entry["type"]["def"] - type_path = type_entry["type"].get("path") - if type_entry.get("params") or "variant" in type_def: - continue - if type_path: - type_name = type_path[-1] - registry_type_map[type_name] = type_id - type_id_to_name[type_id] = type_name - else: - # Possibly a primitive - if "primitive" in type_def: - prim_name = type_def["primitive"] - registry_type_map[prim_name] = type_id - type_id_to_name[type_id] = prim_name - - # Pass 2: Resolve remaining types - pending_ids = set(type_by_id.keys()) - set(type_id_to_name.keys()) - - def resolve_type_definition(type_id_): - type_entry_ = type_by_id[type_id_] - type_def_ = type_entry_["type"]["def"] - type_path_ = type_entry_["type"].get("path", []) - type_params = type_entry_["type"].get("params", []) - - if type_id_ in type_id_to_name: - return type_id_to_name[type_id_] - - # Resolve complex types with paths (including generics like Option etc) - if type_path_: - type_name_ = type_path_[-1] - if type_params: - inner_names = [] - for param in type_params: - dep_id = param["type"] - if dep_id not in type_id_to_name: - return None - inner_names.append(type_id_to_name[dep_id]) - return f"{type_name_}<{', '.join(inner_names)}>" - if 
"variant" in type_def_: - return None - return type_name_ - - elif "sequence" in type_def_: - sequence_type_id = type_def_["sequence"]["type"] - inner_type = type_id_to_name.get(sequence_type_id) - if inner_type: - type_name_ = f"Vec<{inner_type}>" - return type_name_ - - elif "array" in type_def_: - array_type_id = type_def_["array"]["type"] - inner_type = type_id_to_name.get(array_type_id) - maybe_len = type_def_["array"].get("len") - if inner_type: - if maybe_len: - type_name_ = f"[{inner_type}; {maybe_len}]" - else: - type_name_ = f"[{inner_type}]" - return type_name_ - - elif "compact" in type_def_: - compact_type_id = type_def_["compact"]["type"] - inner_type = type_id_to_name.get(compact_type_id) - if inner_type: - type_name_ = f"Compact<{inner_type}>" - return type_name_ - - elif "tuple" in type_def_: - tuple_type_ids = type_def_["tuple"] - type_names = [] - for inner_type_id in tuple_type_ids: - if inner_type_id not in type_id_to_name: - return None - type_names.append(type_id_to_name[inner_type_id]) - type_name_ = ", ".join(type_names) - type_name_ = f"({type_name_})" - return type_name_ - - elif "variant" in type_def_: - return None - - return None - - resolved_type = True - while resolved_type and pending_ids: - resolved_type = False - for type_id in list(pending_ids): - name = resolve_type_definition(type_id) - if name is not None: - type_id_to_name[type_id] = name - registry_type_map[name] = type_id - pending_ids.remove(type_id) - resolved_type = True - - self.registry_type_map = registry_type_map - self.type_id_to_name = type_id_to_name - - def reload_type_registry( - self, use_remote_preset: bool = True, auto_discover: bool = True - ): - """ - Reload type registry and preset used to instantiate the `AsyncSubstrateInterface` object. Useful to - periodically apply changes in type definitions when a runtime upgrade occurred - - Args: - use_remote_preset: When True preset is downloaded from Github master, - otherwise use files from local installed scalecodec package - auto_discover: Whether to automatically discover the type_registry - presets based on the chain name and typer registry - """ - self.runtime_config.clear_type_registry() - - self.runtime_config.implements_scale_info = self.implements_scaleinfo - - # Load metadata types in runtime configuration - self.runtime_config.update_type_registry(load_type_registry_preset(name="core")) - self.apply_type_registry_presets( - use_remote_preset=use_remote_preset, auto_discover=auto_discover - ) - - def apply_type_registry_presets( - self, use_remote_preset: bool = True, auto_discover: bool = True - ): - if self.type_registry_preset is not None: - # Load type registry according to preset - type_registry_preset_dict = load_type_registry_preset( - name=self.type_registry_preset, use_remote_preset=use_remote_preset - ) - - if not type_registry_preset_dict: - raise ValueError( - f"Type registry preset '{self.type_registry_preset}' not found" - ) - - elif auto_discover: - # Try to auto discover type registry preset by chain name - type_registry_name = self.chain.lower().replace(" ", "-") - try: - type_registry_preset_dict = load_type_registry_preset( - type_registry_name - ) - logger.debug( - f"Auto set type_registry_preset to {type_registry_name} ..." 
-                    )
-                self.type_registry_preset = type_registry_name
-            except ValueError:
-                type_registry_preset_dict = None
-
-        else:
-            type_registry_preset_dict = None
-
-        if type_registry_preset_dict:
-            # Load type registries in runtime configuration
-            if self.implements_scaleinfo is False:
-                # Only runtime with no embedded types in metadata need the default set of explicit defined types
-                self.runtime_config.update_type_registry(
-                    load_type_registry_preset(
-                        "legacy", use_remote_preset=use_remote_preset
-                    )
-                )
-
-            if self.type_registry_preset != "legacy":
-                self.runtime_config.update_type_registry(type_registry_preset_dict)
-
-        if self.type_registry:
-            # Load type registries in runtime configuration
-            self.runtime_config.update_type_registry(self.type_registry)
-
     def extension_call(self, name, **kwargs):
         raise NotImplementedError(
             "Extensions not implemented in AsyncSubstrateInterface"
@@ -836,13 +863,16 @@ def make_payload(id_: str, method: str, params: list) -> dict:
         "payload": {"jsonrpc": "2.0", "method": method, "params": params},
     }

-    def _encode_scale(self, type_string, value: Any) -> bytes:
+    def _encode_scale(
+        self, type_string, value: Any, runtime: Optional[Runtime] = None
+    ) -> bytes:
         """
         Helper function to encode arbitrary data into SCALE-bytes for given RUST type_string

         Args:
             type_string: the type string of the SCALE object for decoding
             value: value to encode
+            runtime: Optional Runtime whose registry to use for encoding

         Returns:
             encoded bytes
@@ -850,14 +880,16 @@
         if value is None:
             result = b"\x00"
         else:
+            if not runtime:
+                runtime = self.runtime
             try:
                 vec_acct_id = (
-                    f"scale_info::{self.registry_type_map['Vec<AccountId32>']}"
+                    f"scale_info::{runtime.registry_type_map['Vec<AccountId32>']}"
                 )
             except KeyError:
                 vec_acct_id = "scale_info::152"
             try:
-                optional_acct_u16 = f"scale_info::{self.registry_type_map['Option<(AccountId32, u16)>']}"
+                optional_acct_u16 = f"scale_info::{runtime.registry_type_map['Option<(AccountId32, u16)>']}"
             except KeyError:
                 optional_acct_u16 = "scale_info::579"

@@ -902,12 +934,11 @@
             else:
                 value = value.value  # Unwrap the value of the type

-            result = bytes(
-                encode_by_type_string(type_string, self.runtime.registry, value)
-            )
+            result = bytes(encode_by_type_string(type_string, runtime.registry, value))
         return result
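The `registry_type_map` consulted here is built per-Runtime by `load_registry_type_map()` (see the types.py changes below); it maps human-readable type names to their `scale_info` ids, with the hardcoded ids (152, 579) as fallbacks only. A hypothetical illustration of the lookup shape, with made-up ids (real ids vary per runtime):

```python
# Hypothetical contents; real ids are derived from the portable registry
# by Runtime.load_registry_type_map().
registry_type_map = {
    "AccountId32": 0,
    "Vec<AccountId32>": 152,
    "Option<(AccountId32, u16)>": 579,
}


def type_string_for(name: str, fallback: int) -> str:
    # Mirrors the try/except KeyError pattern used in _encode_scale above.
    return f"scale_info::{registry_type_map.get(name, fallback)}"


assert type_string_for("Vec<AccountId32>", 152) == "scale_info::152"
```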
-    def _encode_account_id(self, account) -> bytes:
+    @staticmethod
+    def _encode_account_id(account) -> bytes:
         """Encode an account ID into bytes.

         Args:
diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py
index fa4be3c..23bbf9f 100644
--- a/async_substrate_interface/utils/cache.py
+++ b/async_substrate_interface/utils/cache.py
@@ -1,14 +1,13 @@
 import asyncio
+import inspect
 from collections import OrderedDict
 import functools
+import logging
 import os
 import pickle
 import sqlite3
 from pathlib import Path
-from typing import Callable, Any
-
-import asyncstdlib as a
-
+from typing import Callable, Any, Awaitable, Hashable, Optional

 USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
 CACHE_LOCATION = (
@@ -19,6 +18,8 @@
     else ":memory:"
 )

+logger = logging.getLogger("async_substrate_interface")
+

 def _ensure_dir():
     path = Path(CACHE_LOCATION).parent
@@ -70,7 +71,7 @@ def _retrieve_from_cache(c, table_name, key, chain):
         if result is not None:
             return pickle.loads(result[0])
     except (pickle.PickleError, sqlite3.Error) as e:
-        print(f"Cache error: {str(e)}")
+        logger.exception("Cache error", exc_info=e)
         pass

@@ -82,7 +83,7 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
         )
         conn.commit()
     except (pickle.PickleError, sqlite3.Error) as e:
-        print(f"Cache error: {str(e)}")
+        logger.exception("Cache error", exc_info=e)
         pass

@@ -128,7 +129,7 @@ def inner(self, *args, **kwargs):

 def async_sql_lru_cache(maxsize=None):
     def decorator(func):
-        @a.lru_cache(maxsize=maxsize)
+        @cached_fetcher(max_size=maxsize)
         async def inner(self, *args, **kwargs):
             c, conn, table_name, key, result, chain, local_chain = (
                 _shared_inner_fn_logic(func, self, args, kwargs)
@@ -147,6 +148,10 @@ async def inner(self, *args, **kwargs):

 class LRUCache:
+    """
+    Basic Least-Recently-Used Cache, with simple methods `set` and `get`
+    """
+
     def __init__(self, max_size: int):
         self.max_size = max_size
         self.cache = OrderedDict()
@@ -167,31 +172,121 @@ def get(self, key):

 class CachedFetcher:
-    def __init__(self, max_size: int, method: Callable):
-        self._inflight: dict[int, asyncio.Future] = {}
+    """
+    Async caching class that provides standard async LRU caching, but also allows concurrent
+    asyncio calls (with the same args) to share the result of a single call.
+
+    This should only be used for asyncio calls where the result is immutable.
+
+    Concept and usage:
+    ```
+    async def fetch(self, block_hash: str) -> str:
+        return await some_resource(block_hash)
+
+    a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
+    ```
+
+    Here, you are making three requests, but you really only need to make two I/O requests
+    (one for "a", one for "b"), and while you wouldn't typically make a request like this directly, it's very
+    common in using this library to inadvertently make these requests by gathering multiple resources that
+    depend on calls like this under the hood.
+
+    By using
+
+    ```
+    @cached_fetcher(max_size=512)
+    async def fetch(self, block_hash: str) -> str:
+        return await some_resource(block_hash)
+
+    a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
+    ```
+
+    you are only making two I/O calls, and a2 will simply use the result of a1 when it lands.
+    """
+
+    def __init__(
+        self,
+        max_size: int,
+        method: Callable[..., Awaitable[Any]],
+        cache_key_index: Optional[int] = 0,
+    ):
+        """
+        Args:
+            max_size: max size of the cache (in items)
+            method: the function to cache
+            cache_key_index: if the method takes multiple args, this is the index of that cache key in the args list
+                (default is the first arg).
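+                The index is applied to the method's bound signature, so `self` is not counted.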
By setting this to `None`, it will use all args as the cache key. + """ + self._inflight: dict[Hashable, asyncio.Future] = {} self._method = method self._cache = LRUCache(max_size=max_size) + self._cache_key_index = cache_key_index - async def execute(self, single_arg: Any) -> str: - if item := self._cache.get(single_arg): + def make_cache_key(self, args: tuple, kwargs: dict) -> Hashable: + bound = inspect.signature(self._method).bind(*args, **kwargs) + bound.apply_defaults() + + if self._cache_key_index is not None: + key_name = list(bound.arguments)[self._cache_key_index] + return bound.arguments[key_name] + + return (tuple(bound.arguments.items()),) + + async def __call__(self, *args: Any, **kwargs: Any) -> Any: + key = self.make_cache_key(args, kwargs) + + if item := self._cache.get(key): return item - if single_arg in self._inflight: - result = await self._inflight[single_arg] - return result + if key in self._inflight: + return await self._inflight[key] loop = asyncio.get_running_loop() future = loop.create_future() - self._inflight[single_arg] = future + self._inflight[key] = future try: - result = await self._method(single_arg) - self._cache.set(single_arg, result) + result = await self._method(*args, **kwargs) + self._cache.set(key, result) future.set_result(result) return result except Exception as e: - # Propagate errors future.set_exception(e) raise finally: - self._inflight.pop(single_arg, None) + self._inflight.pop(key, None) + + +class _CachedFetcherMethod: + """ + Helper class for using CachedFetcher with method caches (rather than functions) + """ + + def __init__(self, method, max_size: int, cache_key_index: int): + self.method = method + self.max_size = max_size + self.cache_key_index = cache_key_index + self._instances = {} + + def __get__(self, instance, owner): + if instance is None: + return self + + # Cache per-instance + if instance not in self._instances: + bound_method = self.method.__get__(instance, owner) + self._instances[instance] = CachedFetcher( + max_size=self.max_size, + method=bound_method, + cache_key_index=self.cache_key_index, + ) + return self._instances[instance] + + +def cached_fetcher(max_size: int, cache_key_index: int = 0): + """Wrapper for CachedFetcher. 
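Use it as a decorator on async methods whose results are immutable.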
See example in CachedFetcher docstring."""

    def wrapper(method):
        return _CachedFetcherMethod(method, max_size, cache_key_index)

    return wrapper
diff --git a/async_substrate_interface/utils/decoding.py b/async_substrate_interface/utils/decoding.py
index 6dc7f21..af8d969 100644
--- a/async_substrate_interface/utils/decoding.py
+++ b/async_substrate_interface/utils/decoding.py
@@ -1,8 +1,7 @@
-from typing import Union, TYPE_CHECKING
+from typing import Union, TYPE_CHECKING, Any

 from bt_decode import AxonInfo, PrometheusInfo, decode_list
-from scalecodec import ss58_encode
-from bittensor_wallet.utils import SS58_FORMAT
+from scalecodec import ScaleBytes, ss58_encode

 from async_substrate_interface.utils import hex_to_bytes
 from async_substrate_interface.types import ScaleObj
@@ -57,10 +56,16 @@ def _bt_decode_to_dict_or_list(obj) -> Union[dict, list[dict]]:
 def _decode_scale_list_with_runtime(
     type_strings: list[str],
     scale_bytes_list: list[bytes],
-    runtime_registry,
+    runtime: "Runtime",
     return_scale_obj: bool = False,
 ):
-    obj = decode_list(type_strings, runtime_registry, scale_bytes_list)
+    if runtime.metadata_v15 is not None:
+        obj = decode_list(type_strings, runtime.registry, scale_bytes_list)
+    else:
+        obj = [
+            legacy_scale_decode(x, y, runtime)
+            for (x, y) in zip(type_strings, scale_bytes_list)
+        ]
     if return_scale_obj:
         return [ScaleObj(x) for x in obj]
     else:
@@ -68,7 +73,7 @@


 def decode_query_map(
-    result_group_changes,
+    result_group_changes: list,
     prefix,
     runtime: "Runtime",
     param_types,
@@ -76,6 +81,7 @@
     value_type,
     key_hashers,
     ignore_decoding_errors,
+    decode_ss58: bool = False,
 ):
     def concat_hash_len(key_hasher: str) -> int:
         """
@@ -111,16 +117,25 @@ def concat_hash_len(key_hasher: str) -> int:
     all_decoded = _decode_scale_list_with_runtime(
         pre_decoded_key_types + pre_decoded_value_types,
         pre_decoded_keys + pre_decoded_values,
-        runtime.registry,
+        runtime,
     )
     middl_index = len(all_decoded) // 2
     decoded_keys = all_decoded[:middl_index]
-    decoded_values = [ScaleObj(x) for x in all_decoded[middl_index:]]
-    for dk, dv in zip(decoded_keys, decoded_values):
+    decoded_values = all_decoded[middl_index:]
+    for kts, vts, dk, dv in zip(
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        decoded_keys,
+        decoded_values,
+    ):
         try:
             # strip key_hashers to use as item key
             if len(param_types) - len(params) == 1:
                 item_key = dk[1]
+                if decode_ss58:
+                    if kts[kts.index(", ") + 2 : kts.index(")")] == "scale_info::0":
+                        item_key = ss58_encode(bytes(item_key[0]), runtime.ss58_format)
+            else:
                 item_key = tuple(
                     dk[key + 1]
                     for key in range(len(params), len(param_types) + 1, 2)
@@ -130,7 +145,95 @@ def concat_hash_len(key_hasher: str) -> int:
             if not ignore_decoding_errors:
                 raise
             item_key = None
-        item_value = dv
-        result.append([item_key, item_value])
+        item_value = dv
+        if decode_ss58:
+            try:
+                value_type_str_int = int(vts.split("::")[1])
+                decoded_type_str = runtime.type_id_to_name[value_type_str_int]
+                item_value = convert_account_ids(
+                    dv, decoded_type_str, runtime.ss58_format
+                )
+            except (ValueError, KeyError):
+                pass
+        result.append([item_key, ScaleObj(item_value)])

     return result
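Both the key and value branches above delegate the actual AccountId32 rewriting to `ss58_encode`/`convert_account_ids`, keyed off the decoded type string. For illustration, this is the underlying transformation, shown with the well-known Alice development key:

```python
from scalecodec import ss58_encode

# Alice's well-known development public key (32 bytes)
alice_pub = bytes.fromhex(
    "d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d"
)
# The decode paths represent an AccountId32 as a tuple of 32 ints,
# which ss58_encode turns back into an address string (format 42 here).
raw = tuple(alice_pub)
assert ss58_encode(bytes(raw), 42) == (
    "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"
)
```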
+
+
+def legacy_scale_decode(
+    type_string: str, scale_bytes: Union[str, ScaleBytes], runtime: "Runtime"
+):
+    if isinstance(scale_bytes, (str, bytes)):
+        scale_bytes = ScaleBytes(scale_bytes)
+
+    obj = runtime.runtime_config.create_scale_object(
+        type_string=type_string, data=scale_bytes, metadata=runtime.metadata
+    )
+
+    obj.decode(check_remaining=runtime.config.get("strict_scale_decode"))
+
+    return obj.value
+
+
+def is_accountid32(value: Any) -> bool:
+    return (
+        isinstance(value, tuple)
+        and len(value) == 32
+        and all(isinstance(b, int) and 0 <= b <= 255 for b in value)
+    )
+
+
+def convert_account_ids(value: Any, type_str: str, ss58_format=42) -> Any:
+    if "AccountId32" not in type_str:
+        return value
+
+    # Option<T>
+    if type_str.startswith("Option<") and value is not None:
+        inner_type = type_str[7:-1]
+        return convert_account_ids(value, inner_type, ss58_format)
+    # Vec<T>
+    if type_str.startswith("Vec<") and isinstance(value, (list, tuple)):
+        inner_type = type_str[4:-1]
+        return tuple(convert_account_ids(v, inner_type, ss58_format) for v in value)
+
+    # Vec<Vec<T>>
+    if type_str.startswith("Vec<Vec<") and isinstance(value, (list, tuple)):
+        inner_type = type_str[4:-1]
+        return tuple(convert_account_ids(v, inner_type, ss58_format) for v in value)
+
+    # Tuple
+    if type_str.startswith("(") and isinstance(value, (list, tuple)):
+        inner_types = split_tuple_type(type_str)
+        return tuple(
+            convert_account_ids(v, t, ss58_format) for v, t in zip(value, inner_types)
+        )
+
+    # AccountId32
+    if is_accountid32(value):
+        return ss58_encode(bytes(value), ss58_format)
+
+    return value
+
+
+def split_tuple_type(type_str: str) -> list[str]:
+    """
+    Splits a type string like '(AccountId32, Vec<u8>)' into ['AccountId32', 'Vec<u8>']
+    Handles nested generics.
+    """
+    s = type_str[1:-1]
+    parts = []
+    depth = 0
+    current = ""
+    for char in s:
+        if char == "," and depth == 0:
+            parts.append(current.strip())
+            current = ""
+        else:
+            if char == "<":
+                depth += 1
+            elif char == ">":
+                depth -= 1
+            current += char
+    if current:
+        parts.append(current.strip())
+    return parts
diff --git a/pyproject.toml b/pyproject.toml
index 389ea1a..80f78b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "async-substrate-interface"
-version = "1.3.1"
+version = "1.4.0"
 description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
 readme = "README.md"
 license = { file = "LICENSE" }
@@ -8,7 +8,6 @@
 keywords = ["substrate", "development", "bittensor"]
 dependencies = [
     "wheel",
-    "asyncstdlib~=3.13.0",
     "bt-decode==v0.6.0",
     "scalecodec~=1.2.11",
     "websockets>=14.1",
diff --git a/tests/helpers/settings.py b/tests/helpers/settings.py
index ae1d7cb..0e9e1da 100644
--- a/tests/helpers/settings.py
+++ b/tests/helpers/settings.py
@@ -32,3 +32,7 @@
 AURA_NODE_URL = (
     environ.get("SUBSTRATE_AURA_NODE_URL") or "wss://acala-rpc-1.aca-api.network"
 )
+
+ARCHIVE_ENTRYPOINT = "wss://archive.chain.opentensor.ai:443"
+
+LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443"
diff --git a/tests/integration_tests/test_async_substrate_interface.py b/tests/integration_tests/test_async_substrate_interface.py
new file mode 100644
index 0000000..2c50213
--- /dev/null
+++ b/tests/integration_tests/test_async_substrate_interface.py
@@ -0,0 +1,134 @@
+import time
+
+import pytest
+from scalecodec import ss58_encode
+
+from async_substrate_interface.async_substrate import AsyncSubstrateInterface
+from async_substrate_interface.types import ScaleObj
+from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
+
+
+@pytest.mark.asyncio
+async def test_legacy_decoding():
+    # roughly 4000 blocks before metadata v15 was added
+    pre_metadata_v15_block = 3_010_611
+
+    async with AsyncSubstrateInterface(ARCHIVE_ENTRYPOINT) as substrate:
+        block_hash = await substrate.get_block_hash(pre_metadata_v15_block)
+        events = await substrate.get_events(block_hash)
+        assert isinstance(events, list)
+
+        query_map_result = await substrate.query_map(
+            module="SubtensorModule",
+            storage_function="NetworksAdded",
+            block_hash=block_hash,
+        )
+        async for key, value in query_map_result:
+            assert isinstance(key, int)
+            assert isinstance(value, ScaleObj)
+
+        timestamp = await substrate.query(
+            "Timestamp",
+            "Now",
+            block_hash=block_hash,
+        )
+        assert timestamp.value == 1716358476004
+
+
+@pytest.mark.asyncio
+async def test_ss58_conversion():
+    async with 
diff --git a/pyproject.toml b/pyproject.toml
index 389ea1a..80f78b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "async-substrate-interface"
-version = "1.3.1"
+version = "1.4.0"
 description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
 readme = "README.md"
 license = { file = "LICENSE" }
@@ -8,7 +8,6 @@
 keywords = ["substrate", "development", "bittensor"]
 dependencies = [
     "wheel",
-    "asyncstdlib~=3.13.0",
     "bt-decode==v0.6.0",
     "scalecodec~=1.2.11",
     "websockets>=14.1",
diff --git a/tests/helpers/settings.py b/tests/helpers/settings.py
index ae1d7cb..0e9e1da 100644
--- a/tests/helpers/settings.py
+++ b/tests/helpers/settings.py
@@ -32,3 +32,7 @@
 AURA_NODE_URL = (
     environ.get("SUBSTRATE_AURA_NODE_URL") or "wss://acala-rpc-1.aca-api.network"
 )
+
+ARCHIVE_ENTRYPOINT = "wss://archive.chain.opentensor.ai:443"
+
+LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443"
diff --git a/tests/integration_tests/test_async_substrate_interface.py b/tests/integration_tests/test_async_substrate_interface.py
new file mode 100644
index 0000000..2c50213
--- /dev/null
+++ b/tests/integration_tests/test_async_substrate_interface.py
@@ -0,0 +1,134 @@
+import time
+
+import pytest
+from scalecodec import ss58_encode
+
+from async_substrate_interface.async_substrate import AsyncSubstrateInterface
+from async_substrate_interface.types import ScaleObj
+from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
+
+
+@pytest.mark.asyncio
+async def test_legacy_decoding():
+    # roughly 4000 blocks before metadata v15 was added
+    pre_metadata_v15_block = 3_010_611
+
+    async with AsyncSubstrateInterface(ARCHIVE_ENTRYPOINT) as substrate:
+        block_hash = await substrate.get_block_hash(pre_metadata_v15_block)
+        events = await substrate.get_events(block_hash)
+        assert isinstance(events, list)
+
+        query_map_result = await substrate.query_map(
+            module="SubtensorModule",
+            storage_function="NetworksAdded",
+            block_hash=block_hash,
+        )
+        async for key, value in query_map_result:
+            assert isinstance(key, int)
+            assert isinstance(value, ScaleObj)
+
+        timestamp = await substrate.query(
+            "Timestamp",
+            "Now",
+            block_hash=block_hash,
+        )
+        assert timestamp.value == 1716358476004
+
+
+@pytest.mark.asyncio
+async def test_ss58_conversion():
+    async with AsyncSubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, decode_ss58=False
+    ) as substrate:
+        block_hash = await substrate.get_chain_finalised_head()
+        qm = await substrate.query_map(
+            "SubtensorModule",
+            "OwnedHotkeys",
+            block_hash=block_hash,
+        )
+        # only check the first page, because otherwise this would be massive
+        for key, value in qm.records:
+            assert isinstance(key, tuple)
+            assert isinstance(value, ScaleObj)
+            assert isinstance(value.value, list)
+            assert len(key) == 1
+            for key_tuple in value.value:
+                assert len(key_tuple[0]) == 32
+                random_key = key_tuple[0]
+
+                ss58_of_key = ss58_encode(bytes(random_key), substrate.ss58_format)
+                assert isinstance(ss58_of_key, str)
+
+        substrate.decode_ss58 = True  # switch decoding on
+
+        qm = await substrate.query_map(
+            "SubtensorModule",
+            "OwnedHotkeys",
+            block_hash=block_hash,
+        )
+        for key, value in qm.records:
+            assert isinstance(key, str)
+            assert isinstance(value, ScaleObj)
+            assert isinstance(value.value, list)
+            if len(value.value) > 0:
+                for decoded_key in value.value:
+                    assert isinstance(decoded_key, str)
+
+
+@pytest.mark.asyncio
+async def test_fully_exhaust_query_map():
+    async with AsyncSubstrateInterface(LATENT_LITE_ENTRYPOINT) as substrate:
+        block_hash = await substrate.get_chain_finalised_head()
+        non_fully_exhausted_start = time.time()
+        non_fully_exhausted_qm = await substrate.query_map(
+            "SubtensorModule",
+            "CRV3WeightCommits",
+            block_hash=block_hash,
+        )
+        initial_records_count = len(non_fully_exhausted_qm.records)
+        assert initial_records_count <= 100  # default page size
+        exhausted_records_count = 0
+        async for _ in non_fully_exhausted_qm:
+            exhausted_records_count += 1
+        non_fully_exhausted_time = time.time() - non_fully_exhausted_start
+
+        assert len(non_fully_exhausted_qm.records) >= initial_records_count
+        fully_exhausted_start = time.time()
+        fully_exhausted_qm = await substrate.query_map(
+            "SubtensorModule",
+            "CRV3WeightCommits",
+            block_hash=block_hash,
+            fully_exhaust=True,
+        )
+
+        fully_exhausted_time = time.time() - fully_exhausted_start
+        initial_records_count_fully_exhaust = len(fully_exhausted_qm.records)
+        assert fully_exhausted_time <= non_fully_exhausted_time, (
+            f"Fully exhausted took longer than non-fully exhausted: "
+            f"{len(non_fully_exhausted_qm.records)} records non-fully exhausted "
+            f"in {non_fully_exhausted_time} seconds vs "
+            f"{initial_records_count_fully_exhaust} records fully exhausted in "
+            f"{fully_exhausted_time} seconds. This can happen when, at this specific "
+            f"block, there are fewer records than fit in a single page; the "
+            f"difference should still be small."
+        )
+        fully_exhausted_records_count = 0
+        async for _ in fully_exhausted_qm:
+            fully_exhausted_records_count += 1
+        assert fully_exhausted_records_count == initial_records_count_fully_exhaust
+        assert initial_records_count_fully_exhaust == exhausted_records_count
+
+
+@pytest.mark.asyncio
+async def test_get_events_proper_decoding():
+    # known block/hash pair that has the events we seek to decode
+    block = 5846788
+    block_hash = "0x0a1c45063a59b934bfee827caa25385e60d5ec1fd8566a58b5cc4affc4eec412"
+
+    async with AsyncSubstrateInterface(ARCHIVE_ENTRYPOINT) as substrate:
+        all_events = await substrate.get_events(block_hash=block_hash)
+        event = all_events[1]
+        print(type(event["attributes"]))
+        assert event["attributes"] == (
+            "5G1NjW9YhXLadMWajvTkfcJy6up3yH2q1YzMXDTi6ijanChe",
+            30,
+            "0xa6b4e5c8241d60ece0c25056b19f7d21ae845269fc771ad46bf3e011865129a5",
+        )
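
`test_fully_exhaust_query_map` above exercises the new `fully_exhaust` flag on `query_map` (PR #144): by default, pages of a storage map are fetched lazily as you iterate, whereas `fully_exhaust=True` pulls every page up front, which the timing assertion expects to be at least as fast overall. A minimal usage sketch, assuming the lite endpoint from the test settings:

```python
import asyncio

from async_substrate_interface.async_substrate import AsyncSubstrateInterface


async def main():
    async with AsyncSubstrateInterface("wss://lite.sub.latent.to:443") as substrate:
        block_hash = await substrate.get_chain_finalised_head()
        qm = await substrate.query_map(
            "SubtensorModule",
            "CRV3WeightCommits",
            block_hash=block_hash,
            fully_exhaust=True,  # fetch all pages up front instead of lazily
        )
        async for key, value in qm:
            print(key, value)


asyncio.run(main())
```
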
diff --git a/tests/integration_tests/test_substrate_interface.py b/tests/integration_tests/test_substrate_interface.py
new file mode 100644
index 0000000..be4eb29
--- /dev/null
+++ b/tests/integration_tests/test_substrate_interface.py
@@ -0,0 +1,86 @@
+from scalecodec import ss58_encode
+
+from async_substrate_interface.sync_substrate import SubstrateInterface
+from async_substrate_interface.types import ScaleObj
+from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
+
+
+def test_legacy_decoding():
+    # roughly 4000 blocks before metadata v15 was added
+    pre_metadata_v15_block = 3_010_611
+
+    with SubstrateInterface(ARCHIVE_ENTRYPOINT) as substrate:
+        block_hash = substrate.get_block_hash(pre_metadata_v15_block)
+        events = substrate.get_events(block_hash)
+        assert isinstance(events, list)
+
+        query_map_result = substrate.query_map(
+            module="SubtensorModule",
+            storage_function="NetworksAdded",
+            block_hash=block_hash,
+        )
+        for key, value in query_map_result:
+            assert isinstance(key, int)
+            assert isinstance(value, ScaleObj)
+
+        timestamp = substrate.query(
+            "Timestamp",
+            "Now",
+            block_hash=block_hash,
+        )
+        assert timestamp.value == 1716358476004
+
+
+def test_ss58_conversion():
+    with SubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, decode_ss58=False
+    ) as substrate:
+        block_hash = substrate.get_chain_finalised_head()
+        qm = substrate.query_map(
+            "SubtensorModule",
+            "OwnedHotkeys",
+            block_hash=block_hash,
+        )
+        # only check the first page, because otherwise this would be massive
+        for key, value in qm.records:
+            assert isinstance(key, tuple)
+            assert isinstance(value, ScaleObj)
+            assert isinstance(value.value, list)
+            assert len(key) == 1
+            for key_tuple in value.value:
+                assert len(key_tuple[0]) == 32
+                random_key = key_tuple[0]
+
+                ss58_of_key = ss58_encode(bytes(random_key), substrate.ss58_format)
+                assert isinstance(ss58_of_key, str)
+
+        substrate.decode_ss58 = True  # switch decoding on
+
+        qm = substrate.query_map(
+            "SubtensorModule",
+            "OwnedHotkeys",
+            block_hash=block_hash,
+        )
+        for key, value in qm.records:
+            assert isinstance(key, str)
+            assert isinstance(value, ScaleObj)
+            assert isinstance(value.value, list)
+            if len(value.value) > 0:
+                for decoded_key in value.value:
+                    assert isinstance(decoded_key, str)
+
+
+def test_get_events_proper_decoding():
+    # known block/hash pair that has the events we seek to decode
+    block = 5846788
+    block_hash = "0x0a1c45063a59b934bfee827caa25385e60d5ec1fd8566a58b5cc4affc4eec412"
+
+    with SubstrateInterface(ARCHIVE_ENTRYPOINT) as substrate:
+        all_events = substrate.get_events(block_hash=block_hash)
+        event = all_events[1]
+        print(type(event["attributes"]))
+        assert event["attributes"] == (
+            "5G1NjW9YhXLadMWajvTkfcJy6up3yH2q1YzMXDTi6ijanChe",
+            30,
+            "0xa6b4e5c8241d60ece0c25056b19f7d21ae845269fc771ad46bf3e011865129a5",
+        )
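
Both integration-test modules exercise the new `decode_ss58` option end to end. With `decode_ss58=False`, AccountId32 keys and values arrive as raw byte tuples that you encode yourself; with `decode_ss58=True` (settable in the constructor, or toggled on a live instance as the tests do) they arrive as ss58 strings. A short sketch of the sync form, reusing the endpoint from the tests:

```python
from async_substrate_interface.sync_substrate import SubstrateInterface

with SubstrateInterface(
    "wss://lite.sub.latent.to:443", ss58_format=42, decode_ss58=True
) as substrate:
    block_hash = substrate.get_chain_finalised_head()
    qm = substrate.query_map(
        "SubtensorModule", "OwnedHotkeys", block_hash=block_hash
    )
    for key, value in qm.records:
        print(key)          # an ss58-encoded account key, e.g. "5Gw9..."
        print(value.value)  # a list of ss58-encoded hotkeys
```
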
diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py
index a64d570..1ea30ef 100644
--- a/tests/unit_tests/asyncio_/test_substrate_interface.py
+++ b/tests/unit_tests/asyncio_/test_substrate_interface.py
@@ -1,5 +1,5 @@
 import asyncio
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, ANY
 
 import pytest
 from websockets.exceptions import InvalidURI
@@ -64,7 +64,7 @@
     # Patch RPC request with correct behavior
     substrate.rpc_request = AsyncMock(
-        side_effect=lambda method, params: {
+        side_effect=lambda method, params, runtime: {
             "result": "0x00" if method == "state_call" else {"parentHash": "0xDEADBEEF"}
         }
     )
@@ -83,14 +83,16 @@
     assert result.value == "decoded_result"
 
     # Check decode_scale called correctly
-    substrate.decode_scale.assert_called_once_with("scale_info::1", b"\x00")
+    substrate.decode_scale.assert_called_once_with(
+        "scale_info::1", b"\x00", runtime=ANY
+    )
 
     # encode_scale should not be called since no inputs
     substrate.encode_scale.assert_not_called()
 
     # Check RPC request called for the state_call
     substrate.rpc_request.assert_any_call(
-        "state_call", ["SubstrateApi_SubstrateMethod", "", None]
+        "state_call", ["SubstrateApi_SubstrateMethod", "", None], runtime=ANY
     )
diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py
index 6d9c471..ea6d7b5 100644
--- a/tests/unit_tests/sync/test_substrate_interface.py
+++ b/tests/unit_tests/sync/test_substrate_interface.py
@@ -72,3 +72,4 @@
     substrate.rpc_request.assert_any_call(
         "state_call", ["SubstrateApi_SubstrateMethod", "", None]
     )
+    substrate.close()
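
The mock updates above reflect an internal API shift in the async interface: RPC and decoding helpers now thread an explicit runtime object (`rpc_request(..., runtime=...)`, `decode_scale(..., runtime=...)`) rather than relying on a single interface-wide runtime. A hedged sketch of what that enables, assuming `init_runtime(block_hash=...)` returns the runtime that was live at that block:

```python
# Sketch only: decode raw SCALE bytes against the runtime of a specific block.
async def decode_at_block(substrate, type_str: str, raw: bytes, block_hash: str):
    runtime = await substrate.init_runtime(block_hash=block_hash)
    return await substrate.decode_scale(type_str, raw, runtime=runtime)
```
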
diff --git a/tests/unit_tests/test_cache.py b/tests/unit_tests/test_cache.py
index 7844202..726c94c 100644
--- a/tests/unit_tests/test_cache.py
+++ b/tests/unit_tests/test_cache.py
@@ -13,18 +13,18 @@
     fetcher = CachedFetcher(max_size=2, method=mock_method)
 
     # First call should trigger the method
-    result1 = await fetcher.execute("key1")
+    result1 = await fetcher("key1")
     assert result1 == "result_key1"
     mock_method.assert_awaited_once_with("key1")
 
     # Second call with the same key should use the cache
-    result2 = await fetcher.execute("key1")
+    result2 = await fetcher("key1")
     assert result2 == "result_key1"
     # Ensure the method was NOT called again
     assert mock_method.await_count == 1
 
     # Third call with a new key triggers a method call
-    result3 = await fetcher.execute("key2")
+    result3 = await fetcher("key2")
     assert result3 == "result_key2"
     assert mock_method.await_count == 2
@@ -42,11 +42,11 @@ async def slow_method(x):
     fetcher = CachedFetcher(max_size=2, method=slow_method)
 
     # Start first request
-    task1 = asyncio.create_task(fetcher.execute("key1"))
+    task1 = asyncio.create_task(fetcher("key1"))
     await asyncio.sleep(0.1)  # Let the task start and be inflight
 
     # Second request for the same key while the first is in-flight
-    task2 = asyncio.create_task(fetcher.execute("key1"))
+    task2 = asyncio.create_task(fetcher("key1"))
     await asyncio.sleep(0.1)
 
     # Release the inflight request
@@ -65,22 +65,25 @@ async def error_method(x):
     fetcher = CachedFetcher(max_size=2, method=error_method)
 
     with pytest.raises(ValueError, match="Boom!"):
-        await fetcher.execute("key1")
+        await fetcher("key1")
 
 
 @pytest.mark.asyncio
 async def test_cached_fetcher_eviction():
     """Tests that LRU eviction works in CachedFetcher."""
-    mock_method = mock.AsyncMock(side_effect=lambda x: f"val_{x}")
-    fetcher = CachedFetcher(max_size=2, method=mock_method)
+
+    async def side_effect_method(x):
+        return f"val_{x}"
+
+    fetcher = CachedFetcher(max_size=2, method=side_effect_method)
 
     # Fill cache
-    await fetcher.execute("key1")
-    await fetcher.execute("key2")
+    await fetcher("key1")
+    await fetcher("key2")
     assert list(fetcher._cache.cache.keys()) == ["key1", "key2"]
 
     # Insert a new key to trigger eviction
-    await fetcher.execute("key3")
+    await fetcher("key3")
 
     # key1 should be evicted
     assert "key1" not in fetcher._cache.cache
     assert "key2" in fetcher._cache.cache
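
The test changes above also document `CachedFetcher`'s new interface: instances are awaited directly (`await fetcher(key)`) rather than through an `.execute()` method, results are LRU-cached up to `max_size`, and concurrent requests for a key that is already in flight share the same underlying call. A minimal usage sketch with a made-up `expensive_lookup` coroutine:

```python
import asyncio

from async_substrate_interface.utils.cache import CachedFetcher


async def expensive_lookup(key: str) -> str:
    await asyncio.sleep(0.5)  # stand-in for an RPC round trip
    return f"result_{key}"


async def main():
    fetcher = CachedFetcher(max_size=512, method=expensive_lookup)
    first = await fetcher("0xabc")   # awaits expensive_lookup
    second = await fetcher("0xabc")  # served from the LRU cache
    assert first == second


asyncio.run(main())
```
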