diff --git a/CHANGELOG.md b/CHANGELOG.md index 3454ca7..009945c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## 1.2.0 /2025-05-07 + +## What's Changed +* Add missing methods by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/104 +* Max subscriptions semaphore added by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/107 +* Expose `_get_block_handler` publicly by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/108 +* safe `__del__` by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/110 +* Tensorshield/main by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/111 +* Support async key implementations by @immortalizzy in https://github.com/opentensor/async-substrate-interface/pull/94 +* Add MetadataAtVersionNotFound error by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/113 +* Fallback chains by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/100 + +## New Contributors +* @immortalizzy made their first contribution in https://github.com/opentensor/async-substrate-interface/pull/94 + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.1.1...v1.2.0 + ## 1.1.1 /2025-04-26 ## What's Changed diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 94abf59..94bda13 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -21,8 +21,6 @@ ) import asyncstdlib as a -from bittensor_wallet.keypair import Keypair -from bittensor_wallet.utils import SS58_FORMAT from bt_decode import MetadataV15, PortableRegistry, decode as decode_by_type_string from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject from scalecodec.types import ( @@ -30,16 +28,20 @@ GenericExtrinsic, GenericRuntimeCallDefinition, ss58_encode, + MultiAccountId, ) from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed +from async_substrate_interface.const import SS58_FORMAT from async_substrate_interface.errors import ( SubstrateRequestException, ExtrinsicNotFound, BlockNotFound, MaxRetriesExceeded, + MetadataAtVersionNotFound, ) +from async_substrate_interface.protocols import Keypair from async_substrate_interface.types import ( ScaleObj, RequestManager, @@ -516,7 +518,7 @@ def __init__( # TODO reconnection logic self.ws_url = ws_url self.ws: Optional["ClientConnection"] = None - self.max_subscriptions = max_subscriptions + self.max_subscriptions = asyncio.Semaphore(max_subscriptions) self.max_connections = max_connections self.shutdown_timer = shutdown_timer self._received = {} @@ -631,6 +633,7 @@ async def send(self, payload: dict) -> int: # async with self._lock: original_id = get_next_id() # self._open_subscriptions += 1 + await self.max_subscriptions.acquire() try: await self.ws.send(json.dumps({**payload, **{"id": original_id}})) return original_id @@ -649,7 +652,9 @@ async def retrieve(self, item_id: int) -> Optional[dict]: retrieved item """ try: - return self._received.pop(item_id) + item = self._received.pop(item_id) + self.max_subscriptions.release() + return item except KeyError: await asyncio.sleep(0.001) return None @@ -730,15 +735,14 @@ async def initialize(self): """ Initialize the connection to the chain. 
""" - async with self._lock: - self._initializing = True - if not self.initialized: - if not self._chain: - chain = await self.rpc_request("system_chain", []) - self._chain = chain.get("result") - await self.init_runtime() - self.initialized = True - self._initializing = False + self._initializing = True + if not self.initialized: + if not self._chain: + chain = await self.rpc_request("system_chain", []) + self._chain = chain.get("result") + await self.init_runtime() + self.initialized = True + self._initializing = False async def __aexit__(self, exc_type, exc_val, exc_tb): pass @@ -813,10 +817,7 @@ async def _load_registry_at_block( "Client error: Execution failed: Other: Exported method Metadata_metadata_at_version is not found" in e.args ): - raise SubstrateRequestException( - "You are attempting to call a block too old for this version of async-substrate-interface. Please" - " instead use legacy py-substrate-interface for these very old blocks." - ) + raise MetadataAtVersionNotFound else: raise e metadata_option_hex_str = metadata_rpc_result["result"] @@ -876,7 +877,7 @@ async def decode_scale( scale_bytes: bytes, _attempt=1, _retries=3, - return_scale_obj=False, + return_scale_obj: bool = False, ) -> Union[ScaleObj, Any]: """ Helper function to decode arbitrary SCALE-bytes (e.g. 0x02000000) according to given RUST type_string @@ -1039,6 +1040,130 @@ async def create_storage_key( metadata=self.runtime.metadata, ) + async def subscribe_storage( + self, + storage_keys: list[StorageKey], + subscription_handler: Callable[[StorageKey, Any, str], Awaitable[Any]], + ): + """ + + Subscribe to provided storage_keys and keep tracking until `subscription_handler` returns a value + + Example of a StorageKey: + ``` + StorageKey.create_from_storage_function( + "System", "Account", ["5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"] + ) + ``` + + Example of a subscription handler: + ``` + async def subscription_handler(storage_key, obj, subscription_id): + if obj is not None: + # the subscription will run until your subscription_handler returns something other than `None` + return obj + ``` + + Args: + storage_keys: StorageKey list of storage keys to subscribe to + subscription_handler: coroutine function to handle value changes of subscription + + """ + await self.init_runtime() + + storage_key_map = {s.to_hex(): s for s in storage_keys} + + async def result_handler( + message: dict, subscription_id: str + ) -> tuple[bool, Optional[Any]]: + result_found = False + subscription_result = None + if "params" in message: + # Process changes + for change_storage_key, change_data in message["params"]["result"][ + "changes" + ]: + # Check for target storage key + storage_key = storage_key_map[change_storage_key] + + if change_data is not None: + change_scale_type = storage_key.value_scale_type + result_found = True + elif ( + storage_key.metadata_storage_function.value["modifier"] + == "Default" + ): + # Fallback to default value of storage function if no result + change_scale_type = storage_key.value_scale_type + change_data = ( + storage_key.metadata_storage_function.value_object[ + "default" + ].value_object + ) + else: + # No result is interpreted as an Option<...> result + change_scale_type = f"Option<{storage_key.value_scale_type}>" + change_data = ( + storage_key.metadata_storage_function.value_object[ + "default" + ].value_object + ) + + # Decode SCALE result data + updated_obj = await self.decode_scale( + type_string=change_scale_type, + scale_bytes=hex_to_bytes(change_data), + ) + + 
subscription_result = await subscription_handler( + storage_key, updated_obj, subscription_id + ) + + if subscription_result is not None: + # Handler returned end result: unsubscribe from further updates + self._forgettable_task = asyncio.create_task( + self.rpc_request( + "state_unsubscribeStorage", [subscription_id] + ) + ) + + return result_found, subscription_result + + if not callable(subscription_handler): + raise ValueError("Provided `subscription_handler` is not callable") + + return await self.rpc_request( + "state_subscribeStorage", + [[s.to_hex() for s in storage_keys]], + result_handler=result_handler, + ) + + async def retrieve_pending_extrinsics(self) -> list: + """ + Retrieves and decodes pending extrinsics from the node's transaction pool + + Returns: + list of extrinsics + """ + + runtime = await self.init_runtime() + + result_data = await self.rpc_request("author_pendingExtrinsics", []) + + extrinsics = [] + + for extrinsic_data in result_data["result"]: + extrinsic = runtime.runtime_config.create_scale_object( + "Extrinsic", metadata=runtime.metadata + ) + extrinsic.decode( + ScaleBytes(extrinsic_data), + check_remaining=self.config.get("strict_scale_decode"), + ) + extrinsics.append(extrinsic) + + return extrinsics + async def get_metadata_storage_functions(self, block_hash=None) -> list: """ Retrieves a list of all storage functions in metadata active at given block_hash (or chaintip if block_hash is @@ -1193,6 +1318,41 @@ async def get_metadata_runtime_call_function( return runtime_call_def_obj + async def get_metadata_runtime_call_function( + self, api: str, method: str + ) -> GenericRuntimeCallDefinition: + """ + Get details of a runtime API call + + Args: + api: Name of the runtime API e.g. 'TransactionPaymentApi' + method: Name of the method e.g. 
'query_fee_details' + + Returns: + GenericRuntimeCallDefinition + """ + await self.init_runtime() + + try: + runtime_call_def = self.runtime_config.type_registry["runtime_api"][api][ + "methods" + ][method] + runtime_call_def["api"] = api + runtime_call_def["method"] = method + runtime_api_types = self.runtime_config.type_registry["runtime_api"][ + api + ].get("types", {}) + except KeyError: + raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") + + # Add runtime API types to registry + self.runtime_config.update_type_registry_types(runtime_api_types) + + runtime_call_def_obj = await self.create_scale_object("RuntimeCallDefinition") + runtime_call_def_obj.encode(runtime_call_def) + + return runtime_call_def_obj + async def _get_block_handler( self, block_hash: str, @@ -1394,6 +1554,8 @@ async def result_handler( response["result"]["block"], block_data_hash=block_hash ) + get_block_handler = _get_block_handler + async def get_block( self, block_hash: Optional[str] = None, @@ -1669,6 +1831,21 @@ def convert_event_data(data): events.append(convert_event_data(item)) return events + async def get_metadata(self, block_hash=None) -> MetadataV15: + """ + Returns the `MetadataV15` object for given block_hash, or for the chaintip if block_hash is omitted + + Args: + block_hash: hash of the block (optional) + + Returns: + MetadataV15 + """ + runtime = await self.init_runtime(block_hash=block_hash) + + return runtime.metadata_v15 + @a.lru_cache(maxsize=512) async def get_parent_block_hash(self, block_hash): return await self._get_parent_block_hash(block_hash) @@ -1685,10 +1862,43 @@ async def _get_parent_block_hash(self, block_hash): return block_hash return parent_block_hash + async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any: + """ + A pass-through to existing JSONRPC method `state_getStorage`/`state_getStorageAt` + + Args: + block_hash: hash of the block + storage_key: storage key to query + + Returns: + result of the query + + """ + + if await self.supports_rpc_method("state_getStorageAt"): + response = await self.rpc_request( + "state_getStorageAt", [storage_key, block_hash] + ) + else: + response = await self.rpc_request( + "state_getStorage", [storage_key, block_hash] + ) + + if "result" in response: + return response.get("result") + elif "error" in response: + raise SubstrateRequestException(response["error"]["message"]) + else: + raise SubstrateRequestException( + "Unknown error occurred during storage retrieval" + ) + @a.lru_cache(maxsize=16) async def get_block_runtime_info(self, block_hash: str) -> dict: return await self._get_block_runtime_info(block_hash) + get_block_runtime_version = get_block_runtime_info + async def _get_block_runtime_info(self, block_hash: str) -> dict: """ Retrieve the runtime info of given block_hash @@ -2415,6 +2625,8 @@ async def create_signed_extrinsic( # Sign payload signature = keypair.sign(signature_payload) + if inspect.isawaitable(signature): + signature = await signature # Create extrinsic extrinsic = self.runtime_config.create_scale_object( @@ -2443,6 +2655,34 @@ return extrinsic + async def create_unsigned_extrinsic(self, call: GenericCall) -> GenericExtrinsic: + """ + Create unsigned extrinsic for given `Call` + + Args: + call: GenericCall the call the extrinsic should contain + + Returns: + GenericExtrinsic + """ + + runtime = await self.init_runtime() + + # Create extrinsic + extrinsic = self.runtime_config.create_scale_object( + type_string="Extrinsic", 
metadata=runtime.metadata + ) + + extrinsic.encode( + { + "call_function": call.value["call_function"], + "call_module": call.value["call_module"], + "call_args": call.value["call_args"], + } + ) + + return extrinsic + async def get_chain_finalised_head(self): """ A pass-through to existing JSONRPC method `chain_getFinalizedHead` @@ -2528,13 +2768,13 @@ async def runtime_call( Returns: ScaleType from the runtime call """ - await self.init_runtime(block_hash=block_hash) + runtime = await self.init_runtime(block_hash=block_hash) if params is None: params = {} try: - metadata_v15_value = self.runtime.metadata_v15.value() + metadata_v15_value = runtime.metadata_v15.value() apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]} api_entry = apis[api] @@ -2627,6 +2867,29 @@ async def get_account_next_index(self, account_address: str) -> int: self._nonces[account_address] += 1 return self._nonces[account_address] + async def get_metadata_constants(self, block_hash=None) -> list[dict]: + """ + Retrieves a list of all constants in metadata active at given block_hash (or chaintip if block_hash is omitted) + + Args: + block_hash: hash of the block + + Returns: + list of constants + """ + + runtime = await self.init_runtime(block_hash=block_hash) + + constant_list = [] + + for module_idx, module in enumerate(self.metadata.pallets): + for constant in module.constants or []: + constant_list.append( + self.serialize_constant(constant, module, runtime.runtime_version) + ) + + return constant_list + async def get_metadata_constant(self, module_name, constant_name, block_hash=None): """ Retrieves the details of a constant for given module name, call function name and block_hash @@ -2994,6 +3257,100 @@ async def query_map( ignore_decoding_errors=ignore_decoding_errors, ) + async def create_multisig_extrinsic( + self, + call: GenericCall, + keypair: Keypair, + multisig_account: MultiAccountId, + max_weight: Optional[Union[dict, int]] = None, + era: dict = None, + nonce: int = None, + tip: int = 0, + tip_asset_id: int = None, + signature: Union[bytes, str] = None, + ) -> GenericExtrinsic: + """ + Create a Multisig extrinsic that will be signed by one of the signatories. Checks on-chain if the threshold + of the multisig account is reached and tries to execute the call accordingly.
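+ + E.g. (an illustrative sketch — the `alice`/`bob`/`charlie`/`dave` keypairs and the call parameters are placeholders, not part of this changeset): + ``` + multisig = substrate.generate_multisig_account( + signatories=[alice.ss58_address, bob.ss58_address, charlie.ss58_address], + threshold=2, + ) + call = await substrate.compose_call( + "Balances", "transfer_keep_alive", {"dest": dave.ss58_address, "value": 10**10} + ) + extrinsic = await substrate.create_multisig_extrinsic(call, alice, multisig) + await substrate.submit_extrinsic(extrinsic) + ```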
+ + Args: + call: GenericCall to create extrinsic for + keypair: Keypair of the signatory to approve given call + multisig_account: MultiAccountId to use as origin of the extrinsic (see `generate_multisig_account()`) + max_weight: Maximum allowed weight to execute the call (uses `get_payment_info()` by default) + era: Specify mortality in blocks in the following format: {'period': [amount_blocks]}. If omitted the extrinsic is + immortal + nonce: nonce to include in extrinsics, if omitted the current nonce is retrieved on-chain + tip: The tip for the block author to gain priority during network congestion + tip_asset_id: Optional asset ID with which to pay the tip + signature: Optionally provide signature if externally signed + + Returns: + GenericExtrinsic + """ + if max_weight is None: + payment_info = await self.get_payment_info(call, keypair) + max_weight = payment_info["weight"] + + # Check if call has existing approvals + multisig_details_ = await self.query( + "Multisig", "Multisigs", [multisig_account.value, call.call_hash] + ) + multisig_details = getattr(multisig_details_, "value", multisig_details_) + if multisig_details: + maybe_timepoint = multisig_details["when"] + else: + maybe_timepoint = None + + # Compose 'as_multi' when final, 'approve_as_multi' otherwise + if ( + multisig_details + and len(multisig_details["approvals"]) + 1 + == multisig_account.threshold + ): + multi_sig_call = await self.compose_call( + "Multisig", + "as_multi", + { + "other_signatories": [ + s + for s in multisig_account.signatories + if s != f"0x{keypair.public_key.hex()}" + ], + "threshold": multisig_account.threshold, + "maybe_timepoint": maybe_timepoint, + "call": call, + "store_call": False, + "max_weight": max_weight, + }, + ) + else: + multi_sig_call = await self.compose_call( + "Multisig", + "approve_as_multi", + { + "other_signatories": [ + s + for s in multisig_account.signatories + if s != f"0x{keypair.public_key.hex()}" + ], + "threshold": multisig_account.threshold, + "maybe_timepoint": maybe_timepoint, + "call_hash": call.call_hash, + "max_weight": max_weight, + }, + ) + + return await self.create_signed_extrinsic( + multi_sig_call, + keypair, + era=era, + nonce=nonce, + tip=tip, + tip_asset_id=tip_asset_id, + signature=signature, + ) + async def submit_extrinsic( self, extrinsic: GenericExtrinsic, @@ -3136,6 +3493,55 @@ async def get_metadata_call_function( return call return None + async def get_metadata_events(self, block_hash=None) -> list[dict]: + """ + Retrieves a list of all events in metadata active for given block_hash (or chaintip if block_hash is omitted) + + Args: + block_hash: hash of the block + + Returns: + list of module events + """ + + runtime = await self.init_runtime(block_hash=block_hash) + + event_list = [] + + for event_index, (module, event) in self.metadata.event_index.items(): + event_list.append( + self.serialize_module_event( + module, event, runtime.runtime_version, event_index + ) + ) + + return event_list + + async def get_metadata_event( + self, module_name, event_name, block_hash=None + ) -> Optional[Any]: + """ + Retrieves the details of an event for given module name, event name and block_hash + (or chaintip if block_hash is omitted) + + Args: + module_name: name of the module + event_name: name of the event + block_hash: hash of the block + + Returns: + Metadata event + + """ + + runtime = await self.init_runtime(block_hash=block_hash) + + for pallet in runtime.metadata.pallets: + if pallet.name == module_name and pallet.events: + for event in 
pallet.events: + if event.name == event_name: + return event + async def get_block_number(self, block_hash: Optional[str] = None) -> int: """Async version of `substrateinterface.base.get_block_number` method.""" response = await self.rpc_request("chain_getHeader", [block_hash]) diff --git a/async_substrate_interface/const.py b/async_substrate_interface/const.py new file mode 100644 index 0000000..983f9e4 --- /dev/null +++ b/async_substrate_interface/const.py @@ -0,0 +1,2 @@ +# Re-define SS58 format here to remove unnecessary dependencies. +SS58_FORMAT = 42 diff --git a/async_substrate_interface/errors.py b/async_substrate_interface/errors.py index 98114fe..c6a2d8d 100644 --- a/async_substrate_interface/errors.py +++ b/async_substrate_interface/errors.py @@ -12,6 +12,16 @@ class MaxRetriesExceeded(SubstrateRequestException): pass +class MetadataAtVersionNotFound(SubstrateRequestException): + def __init__(self): + message = ( + "Exported method Metadata_metadata_at_version is not found. This indicates the block is quite old, and is " + "not supported by async-substrate-interface. If you need this, we recommend using the legacy " + "substrate-interface (https://github.com/JAMdotTech/py-polkadot-sdk)." + ) + super().__init__(message) + + class StorageFunctionNotFound(ValueError): pass diff --git a/async_substrate_interface/protocols.py b/async_substrate_interface/protocols.py new file mode 100644 index 0000000..b50605f --- /dev/null +++ b/async_substrate_interface/protocols.py @@ -0,0 +1,36 @@ +from typing import Awaitable, Protocol, Union, Optional, runtime_checkable + + +__all__: list[str] = ["Keypair"] + + +# For reference only +# class KeypairType: +# """ +# Type of cryptography, used in `Keypair` instance to encrypt and sign data +# +# * ED25519 = 0 +# * SR25519 = 1 +# * ECDSA = 2 +# +# """ +# ED25519 = 0 +# SR25519 = 1 +# ECDSA = 2 + + +@runtime_checkable +class Keypair(Protocol): + @property + def crypto_type(self) -> int: ... + + @property + def public_key(self) -> Optional[bytes]: ... + + @property + def ss58_address(self) -> str: ... + + @property + def ss58_format(self) -> int: ... + + def sign(self, data: Union[bytes, str]) -> Union[bytes, Awaitable[bytes]]: ... diff --git a/async_substrate_interface/substrate_addons.py b/async_substrate_interface/substrate_addons.py new file mode 100644 index 0000000..5edb26a --- /dev/null +++ b/async_substrate_interface/substrate_addons.py @@ -0,0 +1,325 @@ +""" +A number of "plugins" for SubstrateInterface (and AsyncSubstrateInterface). At initial creation, it contains only +the Retry plugins (sync and async versions).
+ """ + + import asyncio + import logging + import socket + from functools import partial + from itertools import cycle + from typing import Optional + + from websockets.exceptions import ConnectionClosed + + from async_substrate_interface.async_substrate import AsyncSubstrateInterface, Websocket + from async_substrate_interface.errors import MaxRetriesExceeded + from async_substrate_interface.sync_substrate import SubstrateInterface + + logger = logging.getLogger("async_substrate_interface") + + + RETRY_METHODS = [ + "_get_block_handler", + "close", + "compose_call", + "create_scale_object", + "create_signed_extrinsic", + "create_storage_key", + "decode_scale", + "encode_scale", + "generate_signature_payload", + "get_account_next_index", + "get_account_nonce", + "get_block", + "get_block_hash", + "get_block_header", + "get_block_metadata", + "get_block_number", + "get_block_runtime_info", + "get_block_runtime_version_for", + "get_chain_finalised_head", + "get_chain_head", + "get_constant", + "get_events", + "get_extrinsics", + "get_metadata_call_function", + "get_metadata_constant", + "get_metadata_error", + "get_metadata_errors", + "get_metadata_module", + "get_metadata_modules", + "get_metadata_runtime_call_function", + "get_metadata_runtime_call_functions", + "get_metadata_storage_function", + "get_metadata_storage_functions", + "get_parent_block_hash", + "get_payment_info", + "get_storage_item", + "get_type_definition", + "get_type_registry", + "init_runtime", + "initialize", + "query", + "query_map", + "query_multi", + "query_multiple", + "retrieve_extrinsic_by_identifier", + "rpc_request", + "runtime_call", + "submit_extrinsic", + "subscribe_block_headers", + "supports_rpc_method", +] + + RETRY_PROPS = ["properties", "version", "token_decimals", "token_symbol", "name"] + + + class RetrySyncSubstrate(SubstrateInterface): + """ + A subclass of SubstrateInterface that allows for handling chain failures by using backup chains. If a sustained + network failure is encountered on a chain endpoint, the object will initialize a new connection on the next chain in + the `fallback_chains` list. If the `retry_forever` flag is set, upon reaching the last chain in `fallback_chains`, + the connection will attempt to iterate over the list (starting with `url`) again. + + E.g. + ``` + substrate = RetrySyncSubstrate( + "wss://entrypoint-finney.opentensor.ai:443", + fallback_chains=["ws://127.0.0.1:9946"] + ) + ``` + In this case, if there is a failure on entrypoint-finney, the connection will next attempt to hit localhost. If this + also fails, a `MaxRetriesExceeded` exception will be raised. + + ``` + substrate = RetrySyncSubstrate( + "wss://entrypoint-finney.opentensor.ai:443", + fallback_chains=["ws://127.0.0.1:9946"], + retry_forever=True + ) + ``` + In this case, rather than a MaxRetriesExceeded exception being raised upon failure of the second chain (localhost), + the object will again begin to initialize a new connection on entrypoint-finney, and then localhost, and so on and + so forth.
+ """ + + def __init__( + self, + url: str, + use_remote_preset: bool = False, + fallback_chains: Optional[list[str]] = None, + retry_forever: bool = False, + ss58_format: Optional[int] = None, + type_registry: Optional[dict] = None, + type_registry_preset: Optional[str] = None, + chain_name: str = "", + max_retries: int = 5, + retry_timeout: float = 60.0, + _mock: bool = False, + ): + fallback_chains = fallback_chains or [] + self.fallback_chains = ( + iter(fallback_chains) + if not retry_forever + else cycle(fallback_chains + [url]) + ) + self.use_remote_preset = use_remote_preset + self.chain_name = chain_name + self._mock = _mock + self.retry_timeout = retry_timeout + self.max_retries = max_retries + self.chain_endpoint = url + self.url = url + initialized = False + for chain_url in [url] + fallback_chains: + try: + self.chain_endpoint = chain_url + self.url = chain_url + super().__init__( + url=chain_url, + ss58_format=ss58_format, + type_registry=type_registry, + use_remote_preset=use_remote_preset, + type_registry_preset=type_registry_preset, + chain_name=chain_name, + _mock=_mock, + retry_timeout=retry_timeout, + max_retries=max_retries, + ) + initialized = True + logger.info(f"Connected to {chain_url}") + break + except ConnectionError: + logger.warning(f"Unable to connect to {chain_url}") + if not initialized: + raise ConnectionError( + f"Unable to connect at any chains specified: {[url] + fallback_chains}" + ) + # "connect" is only used by SubstrateInterface, not AsyncSubstrateInterface + retry_methods = ["connect"] + RETRY_METHODS + self._original_methods = { + method: getattr(self, method) for method in retry_methods + } + for method in retry_methods: + setattr(self, method, partial(self._retry, method)) + + def _retry(self, method, *args, **kwargs): + method_ = self._original_methods[method] + try: + return method_(*args, **kwargs) + except ( + MaxRetriesExceeded, + ConnectionError, + EOFError, + ConnectionClosed, + TimeoutError, + ) as e: + try: + self._reinstantiate_substrate(e) + return method_(*args, **kwargs) + except StopIteration: + logger.error( + f"Max retries exceeded with {self.url}. No more fallback chains." + ) + raise MaxRetriesExceeded + + def _reinstantiate_substrate(self, e: Optional[Exception] = None) -> None: + next_network = next(self.fallback_chains) + self.ws.close() + if e.__class__ == MaxRetriesExceeded: + logger.error( + f"Max retries exceeded with {self.url}. Retrying with {next_network}." + ) + else: + logger.error(f"Connection error. Trying again with {next_network}") + self.url = next_network + self.chain_endpoint = next_network + self.initialized = False + self.ws = self.connect(init=True) + if not self._mock: + self.initialize() + + +class RetryAsyncSubstrate(AsyncSubstrateInterface): + """ + A subclass of AsyncSubstrateInterface that allows for handling chain failures by using backup chains. If a + sustained network failure is encountered on a chain endpoint, the object will initialize a new connection on + the next chain in the `fallback_chains` list. If the `retry_forever` flag is set, upon reaching the last chain + in `fallback_chains`, the connection will attempt to iterate over the list (starting with `url`) again. + + E.g. + ``` + substrate = RetryAsyncSubstrate( + "wss://entrypoint-finney.opentensor.ai:443", + fallback_chains=["ws://127.0.0.1:9946"] + ) + ``` + In this case, if there is a failure on entrypoint-finney, the connection will next attempt to hit localhost. If this + also fails, a `MaxRetriesExceeded` exception will be raised. 
+ + ``` + substrate = RetryAsyncSubstrate( + "wss://entrypoint-finney.opentensor.ai:443", + fallback_chains=["ws://127.0.0.1:9946"], + retry_forever=True + ) + ``` + In this case, rather than a MaxRetriesExceeded exception being raised upon failure of the second chain (localhost), + the object will again begin to initialize a new connection on entrypoint-finney, and then localhost, and so on and + so forth. + """ + + def __init__( + self, + url: str, + use_remote_preset: bool = False, + fallback_chains: Optional[list[str]] = None, + retry_forever: bool = False, + ss58_format: Optional[int] = None, + type_registry: Optional[dict] = None, + type_registry_preset: Optional[str] = None, + chain_name: str = "", + max_retries: int = 5, + retry_timeout: float = 60.0, + _mock: bool = False, + ): + fallback_chains = fallback_chains or [] + self.fallback_chains = ( + iter(fallback_chains) + if not retry_forever + else cycle(fallback_chains + [url]) + ) + self.use_remote_preset = use_remote_preset + self.chain_name = chain_name + self._mock = _mock + self.retry_timeout = retry_timeout + self.max_retries = max_retries + super().__init__( + url=url, + ss58_format=ss58_format, + type_registry=type_registry, + use_remote_preset=use_remote_preset, + type_registry_preset=type_registry_preset, + chain_name=chain_name, + _mock=_mock, + retry_timeout=retry_timeout, + max_retries=max_retries, + ) + self._original_methods = { + method: getattr(self, method) for method in RETRY_METHODS + } + for method in RETRY_METHODS: + setattr(self, method, partial(self._retry, method)) + + async def _reinstantiate_substrate(self, e: Optional[Exception] = None) -> None: + next_network = next(self.fallback_chains) + if e.__class__ == MaxRetriesExceeded: + logger.error( + f"Max retries exceeded with {self.url}. Retrying with {next_network}." + ) + else: + logger.error(f"Connection error. Trying again with {next_network}") + try: + await self.ws.shutdown() + except AttributeError: + pass + if self._forgettable_task is not None: + self._forgettable_task: asyncio.Task + self._forgettable_task.cancel() + try: + await self._forgettable_task + except asyncio.CancelledError: + pass + self.chain_endpoint = next_network + self.url = next_network + self.ws = Websocket( + next_network, + options={ + "max_size": self.ws_max_size, + "write_limit": 2**16, + }, + ) + self._initialized = False + self._initializing = False + await self.initialize() + + async def _retry(self, method, *args, **kwargs): + method_ = self._original_methods[method] + try: + return await method_(*args, **kwargs) + except ( + MaxRetriesExceeded, + ConnectionError, + ConnectionClosed, + EOFError, + socket.gaierror, + ) as e: + try: + await self._reinstantiate_substrate(e) + return await method_(*args, **kwargs) + except StopAsyncIteration: + logger.error( + f"Max retries exceeded with {self.url}. No more fallback chains."
+ ) + raise MaxRetriesExceeded diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index a3a9f4d..4c91fd2 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -1,27 +1,30 @@ import functools import logging +import socket from hashlib import blake2b from typing import Optional, Union, Callable, Any -from bittensor_wallet.keypair import Keypair -from bittensor_wallet.utils import SS58_FORMAT from bt_decode import MetadataV15, PortableRegistry, decode as decode_by_type_string from scalecodec import ( GenericCall, GenericExtrinsic, GenericRuntimeCallDefinition, ss58_encode, + MultiAccountId, ) from scalecodec.base import RuntimeConfigurationObject, ScaleBytes, ScaleType from websockets.sync.client import connect from websockets.exceptions import ConnectionClosed +from async_substrate_interface.const import SS58_FORMAT from async_substrate_interface.errors import ( ExtrinsicNotFound, SubstrateRequestException, BlockNotFound, MaxRetriesExceeded, + MetadataAtVersionNotFound, ) +from async_substrate_interface.protocols import Keypair from async_substrate_interface.types import ( SubstrateMixin, RuntimeCache, @@ -510,7 +513,6 @@ def __init__( "strict_scale_decode": True, } self.initialized = False - self._forgettable_task = None self.ss58_format = ss58_format self.type_registry = type_registry self.type_registry_preset = type_registry_preset @@ -531,8 +533,10 @@ def __enter__(self): return self def __del__(self): - self.ws.close() - print("DELETING SUBSTATE") + try: + self.ws.close() + except AttributeError: + pass # self.ws.protocol.fail(code=1006) # ABNORMAL_CLOSURE def initialize(self): @@ -584,13 +588,19 @@ def name(self): def connect(self, init=False): if init is True: - return connect(self.chain_endpoint, max_size=self.ws_max_size) + try: + return connect(self.chain_endpoint, max_size=self.ws_max_size) + except (ConnectionError, socket.gaierror) as e: + raise ConnectionError(e) else: if not self.ws.close_code: return self.ws else: - self.ws = connect(self.chain_endpoint, max_size=self.ws_max_size) - return self.ws + try: + self.ws = connect(self.chain_endpoint, max_size=self.ws_max_size) + return self.ws + except (ConnectionError, socket.gaierror) as e: + raise ConnectionError(e) def get_storage_item( self, module: str, storage_function: str, block_hash: str = None @@ -614,11 +624,20 @@ def _get_current_block_hash( def _load_registry_at_block(self, block_hash: Optional[str]) -> MetadataV15: # Should be called for any block that fails decoding. # Possibly the metadata was different. 
- metadata_rpc_result = self.rpc_request( - "state_call", - ["Metadata_metadata_at_version", self.metadata_version_hex], - block_hash=block_hash, - ) + try: + metadata_rpc_result = self.rpc_request( + "state_call", + ["Metadata_metadata_at_version", self.metadata_version_hex], + block_hash=block_hash, + ) + except SubstrateRequestException as e: + if ( + "Client error: Execution failed: Other: Exported method Metadata_metadata_at_version is not found" + in e.args + ): + raise MetadataAtVersionNotFound + else: + raise e metadata_option_hex_str = metadata_rpc_result["result"] metadata_option_bytes = bytes.fromhex(metadata_option_hex_str[2:]) metadata = MetadataV15.decode_from_metadata_option(metadata_option_bytes) @@ -796,6 +815,126 @@ def create_storage_key( metadata=self.runtime.metadata, ) + def subscribe_storage( + self, + storage_keys: list[StorageKey], + subscription_handler: Callable[[StorageKey, Any, str], Any], + ): + """ + + Subscribe to provided storage_keys and keep tracking until `subscription_handler` returns a value + + Example of a StorageKey: + ``` + StorageKey.create_from_storage_function( + "System", "Account", ["5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"] + ) + ``` + + Example of a subscription handler: + ``` + def subscription_handler(storage_key, obj, subscription_id): + if obj is not None: + # the subscription will run until your subscription_handler returns something other than `None` + return obj + ``` + + Args: + storage_keys: StorageKey list of storage keys to subscribe to + subscription_handler: function to handle value changes of subscription + + """ + self.init_runtime() + + storage_key_map = {s.to_hex(): s for s in storage_keys} + + def result_handler( + message: dict, subscription_id: str + ) -> tuple[bool, Optional[Any]]: + result_found = False + subscription_result = None + if "params" in message: + # Process changes + for change_storage_key, change_data in message["params"]["result"][ + "changes" + ]: + # Check for target storage key + storage_key = storage_key_map[change_storage_key] + + if change_data is not None: + change_scale_type = storage_key.value_scale_type + result_found = True + elif ( + storage_key.metadata_storage_function.value["modifier"] + == "Default" + ): + # Fallback to default value of storage function if no result + change_scale_type = storage_key.value_scale_type + change_data = ( + storage_key.metadata_storage_function.value_object[ + "default" + ].value_object + ) + else: + # No result is interpreted as an Option<...> result + change_scale_type = f"Option<{storage_key.value_scale_type}>" + change_data = ( + storage_key.metadata_storage_function.value_object[ + "default" + ].value_object + ) + + # Decode SCALE result data + updated_obj = self.decode_scale( + type_string=change_scale_type, + scale_bytes=hex_to_bytes(change_data), + ) + + subscription_result = subscription_handler( + storage_key, updated_obj, subscription_id + ) + + if subscription_result is not None: + # Handler returned end result: unsubscribe from further updates + self.rpc_request("state_unsubscribeStorage", [subscription_id]) + + return result_found, subscription_result + + if not callable(subscription_handler): + raise ValueError("Provided `subscription_handler` is not callable") + + return self.rpc_request( + "state_subscribeStorage", + [[s.to_hex() for s in storage_keys]], + result_handler=result_handler, + ) + + def retrieve_pending_extrinsics(self) -> list: + """ + Retrieves and decodes pending extrinsics from the node's transaction pool + + Returns: + 
list of extrinsics + """ + + runtime = self.init_runtime() + + result_data = self.rpc_request("author_pendingExtrinsics", []) + + extrinsics = [] + + for extrinsic_data in result_data["result"]: + extrinsic = runtime.runtime_config.create_scale_object( + "Extrinsic", metadata=runtime.metadata + ) + extrinsic.decode( + ScaleBytes(extrinsic_data), + check_remaining=self.config.get("strict_scale_decode"), + ) + extrinsics.append(extrinsic) + + return extrinsics + def get_metadata_storage_functions(self, block_hash=None) -> list: """ Retrieves a list of all storage functions in metadata active at given block_hash (or chaintip if block_hash is @@ -946,6 +1085,41 @@ def get_metadata_runtime_call_function( return runtime_call_def_obj + def get_metadata_runtime_call_function( + self, api: str, method: str + ) -> GenericRuntimeCallDefinition: + """ + Get details of a runtime API call + + Args: + api: Name of the runtime API e.g. 'TransactionPaymentApi' + method: Name of the method e.g. 'query_fee_details' + + Returns: + GenericRuntimeCallDefinition + """ + self.init_runtime() + + try: + runtime_call_def = self.runtime_config.type_registry["runtime_api"][api][ + "methods" + ][method] + runtime_call_def["api"] = api + runtime_call_def["method"] = method + runtime_api_types = self.runtime_config.type_registry["runtime_api"][ + api + ].get("types", {}) + except KeyError: + raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry") + + # Add runtime API types to registry + self.runtime_config.update_type_registry_types(runtime_api_types) + + runtime_call_def_obj = self.create_scale_object("RuntimeCallDefinition") + runtime_call_def_obj.encode(runtime_call_def) + + return runtime_call_def_obj + def _get_block_handler( self, block_hash: str, @@ -1141,6 +1315,8 @@ def result_handler(message: dict, subscription_id: str) -> tuple[Any, bool]: response["result"]["block"], block_data_hash=block_hash ) + get_block_handler = _get_block_handler + def get_block( self, block_hash: Optional[str] = None, @@ -1416,6 +1592,21 @@ def convert_event_data(data): events.append(convert_event_data(item)) return events + def get_metadata(self, block_hash=None) -> MetadataV15: + """ + Returns the `MetadataV15` object for given block_hash, or for the chaintip if block_hash is omitted + + Args: + block_hash: hash of the block (optional) + + Returns: + MetadataV15 + """ + runtime = self.init_runtime(block_hash=block_hash) + + return runtime.metadata_v15 + @functools.lru_cache(maxsize=512) def get_parent_block_hash(self, block_hash): block_header = self.rpc_request("chain_getHeader", [block_hash]) @@ -1429,6 +1620,33 @@ def get_parent_block_hash(self, block_hash): return block_hash return parent_block_hash + def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any: + """ + A pass-through to existing JSONRPC method `state_getStorage`/`state_getStorageAt` + + Args: + block_hash: hash of the block + storage_key: storage key to query + + Returns: + result of the query + + """ + + if self.supports_rpc_method("state_getStorageAt"): + response = self.rpc_request("state_getStorageAt", [storage_key, block_hash]) + else: + response = self.rpc_request("state_getStorage", [storage_key, block_hash]) + + if "result" in response: + return response.get("result") + elif "error" in response: + raise SubstrateRequestException(response["error"]["message"]) + else: + raise SubstrateRequestException( + "Unknown error occurred during storage retrieval" + ) + @functools.lru_cache(maxsize=16) def get_block_runtime_info(self, block_hash: str) -> dict: 
""" @@ -1437,6 +1655,8 @@ def get_block_runtime_info(self, block_hash: str) -> dict: response = self.rpc_request("state_getRuntimeVersion", [block_hash]) return response.get("result") + get_block_runtime_version = get_block_runtime_info + @functools.lru_cache(maxsize=512) def get_block_runtime_version_for(self, block_hash: str): """ @@ -2161,6 +2381,34 @@ def create_signed_extrinsic( return extrinsic + def create_unsigned_extrinsic(self, call: GenericCall) -> GenericExtrinsic: + """ + Create unsigned extrinsic for given `Call` + + Args: + call: GenericCall the call the extrinsic should contain + + Returns: + GenericExtrinsic + """ + + runtime = self.init_runtime() + + # Create extrinsic + extrinsic = self.runtime_config.create_scale_object( + type_string="Extrinsic", metadata=runtime.metadata + ) + + extrinsic.encode( + { + "call_function": call.value["call_function"], + "call_module": call.value["call_module"], + "call_args": call.value["call_args"], + } + ) + + return extrinsic + def get_chain_finalised_head(self): """ A pass-though to existing JSONRPC method `chain_getFinalizedHead` @@ -2334,6 +2582,29 @@ def get_account_next_index(self, account_address: str) -> int: nonce_obj = self.rpc_request("account_nextIndex", [account_address]) return nonce_obj["result"] + def get_metadata_constants(self, block_hash=None) -> list[dict]: + """ + Retrieves a list of all constants in metadata active at given block_hash (or chaintip if block_hash is omitted) + + Args: + block_hash: hash of the block + + Returns: + list of constants + """ + + runtime = self.init_runtime(block_hash=block_hash) + + constant_list = [] + + for module_idx, module in enumerate(self.metadata.pallets): + for constant in module.constants or []: + constant_list.append( + self.serialize_constant(constant, module, runtime.runtime_version) + ) + + return constant_list + def get_metadata_constant(self, module_name, constant_name, block_hash=None): """ Retrieves the details of a constant for given module name, call function name and block_hash @@ -2698,6 +2969,100 @@ def query_map( ignore_decoding_errors=ignore_decoding_errors, ) + def create_multisig_extrinsic( + self, + call: GenericCall, + keypair: Keypair, + multisig_account: MultiAccountId, + max_weight: Optional[Union[dict, int]] = None, + era: dict = None, + nonce: int = None, + tip: int = 0, + tip_asset_id: int = None, + signature: Union[bytes, str] = None, + ) -> GenericExtrinsic: + """ + Create a Multisig extrinsic that will be signed by one of the signatories. Checks on-chain if the threshold + of the multisig account is reached and try to execute the call accordingly. 
+ + Args: + call: GenericCall to create extrinsic for + keypair: Keypair of the signatory to approve given call + multisig_account: MultiAccountId to use as origin of the extrinsic (see `generate_multisig_account()`) + max_weight: Maximum allowed weight to execute the call (uses `get_payment_info()` by default) + era: Specify mortality in blocks in the following format: {'period': [amount_blocks]}. If omitted the extrinsic is + immortal + nonce: nonce to include in extrinsics, if omitted the current nonce is retrieved on-chain + tip: The tip for the block author to gain priority during network congestion + tip_asset_id: Optional asset ID with which to pay the tip + signature: Optionally provide signature if externally signed + + Returns: + GenericExtrinsic + """ + if max_weight is None: + payment_info = self.get_payment_info(call, keypair) + max_weight = payment_info["weight"] + + # Check if call has existing approvals + multisig_details = self.query( + "Multisig", "Multisigs", [multisig_account.value, call.call_hash] + ) + + if multisig_details.value: + maybe_timepoint = multisig_details.value["when"] + else: + maybe_timepoint = None + + # Compose 'as_multi' when final, 'approve_as_multi' otherwise + if ( + multisig_details.value + and len(multisig_details.value["approvals"]) + 1 + == multisig_account.threshold + ): + multi_sig_call = self.compose_call( + "Multisig", + "as_multi", + { + "other_signatories": [ + s + for s in multisig_account.signatories + if s != f"0x{keypair.public_key.hex()}" + ], + "threshold": multisig_account.threshold, + "maybe_timepoint": maybe_timepoint, + "call": call, + "store_call": False, + "max_weight": max_weight, + }, + ) + else: + multi_sig_call = self.compose_call( + "Multisig", + "approve_as_multi", + { + "other_signatories": [ + s + for s in multisig_account.signatories + if s != f"0x{keypair.public_key.hex()}" + ], + "threshold": multisig_account.threshold, + "maybe_timepoint": maybe_timepoint, + "call_hash": call.call_hash, + "max_weight": max_weight, + }, + ) + + return self.create_signed_extrinsic( + multi_sig_call, + keypair, + era=era, + nonce=nonce, + tip=tip, + tip_asset_id=tip_asset_id, + signature=signature, + ) + def submit_extrinsic( self, extrinsic: GenericExtrinsic, @@ -2832,6 +3197,55 @@ def get_metadata_call_function( return call return None + def get_metadata_events(self, block_hash=None) -> list[dict]: + """ + Retrieves a list of all events in metadata active for given block_hash (or chaintip if block_hash is omitted) + + Args: + block_hash: hash of the block + + Returns: + list of module events + """ + + runtime = self.init_runtime(block_hash=block_hash) + + event_list = [] + + for event_index, (module, event) in self.metadata.event_index.items(): + event_list.append( + self.serialize_module_event( + module, event, runtime.runtime_version, event_index + ) + ) + + return event_list + + def get_metadata_event( + self, module_name, event_name, block_hash=None + ) -> Optional[Any]: + """ + Retrieves the details of an event for given module name, event name and block_hash + (or chaintip if block_hash is omitted) + + Args: + module_name: name of the module + event_name: name of the event + block_hash: hash of the block + + Returns: + Metadata event + + """ + + runtime = self.init_runtime(block_hash=block_hash) + + for pallet in runtime.metadata.pallets: + if pallet.name == module_name and pallet.events: + for event in pallet.events: + if event.name == event_name: + return event + def get_block_number(self, block_hash: Optional[str] = None) -> int: 
"""Async version of `substrateinterface.base.get_block_number` method.""" response = self.rpc_request("chain_getHeader", [block_hash]) diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index 754b860..5a83895 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -7,12 +7,12 @@ from typing import Optional, Union, Any from bt_decode import PortableRegistry, encode as encode_by_type_string -from bittensor_wallet.utils import SS58_FORMAT from scalecodec import ss58_encode, ss58_decode, is_valid_ss58_address from scalecodec.base import RuntimeConfigurationObject, ScaleBytes from scalecodec.type_registry import load_type_registry_preset -from scalecodec.types import GenericCall, ScaleType +from scalecodec.types import GenericCall, ScaleType, MultiAccountId +from .const import SS58_FORMAT from .utils import json @@ -919,3 +919,27 @@ def _encode_account_id(self, account) -> bytes: if isinstance(account, bytes): return account # Already encoded return bytes.fromhex(ss58_decode(account, SS58_FORMAT)) # SS58 string + + def generate_multisig_account( + self, signatories: list, threshold: int + ) -> MultiAccountId: + """ + Generate deterministic Multisig account with supplied signatories and threshold + + Args: + signatories: List of signatories + threshold: Amount of approvals needed to execute + + Returns: + MultiAccountId + """ + + multi_sig_account = MultiAccountId.create_from_account_list( + signatories, threshold + ) + + multi_sig_account.ss58_address = ss58_encode( + multi_sig_account.value.replace("0x", ""), self.ss58_format + ) + + return multi_sig_account diff --git a/pyproject.toml b/pyproject.toml index 8fe5ee8..cdfa7e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.1.1" +version = "1.2.0" description = "Asyncio library for interacting with substrate. 
Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } @@ -9,7 +9,6 @@ keywords = ["substrate", "development", "bittensor"] dependencies = [ "wheel", "asyncstdlib~=3.13.0", - "bittensor-wallet>=2.1.3", "bt-decode==v0.6.0", "scalecodec~=1.2.11", "websockets>=14.1", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0e312e3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,26 @@ +import subprocess +from collections import namedtuple + +CONTAINER_NAME_PREFIX = "test_local_chain_" +LOCALNET_IMAGE_NAME = "ghcr.io/opentensor/subtensor-localnet:devnet-ready" + +Container = namedtuple("Container", ["process", "name", "uri"]) + + +def start_docker_container(exposed_port, name_salt: str): + container_name = f"{CONTAINER_NAME_PREFIX}{name_salt}" + + # Command to start container + cmds = [ + "docker", + "run", + "--rm", + "--name", + container_name, + "-p", + f"{exposed_port}:9945", + LOCALNET_IMAGE_NAME, + ] + + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return Container(proc, container_name, f"ws://127.0.0.1:{exposed_port}") diff --git a/tests/test_substrate_addons.py b/tests/test_substrate_addons.py new file mode 100644 index 0000000..686a028 --- /dev/null +++ b/tests/test_substrate_addons.py @@ -0,0 +1,74 @@ +import subprocess + +import pytest +import time + +from async_substrate_interface.substrate_addons import RetrySyncSubstrate +from async_substrate_interface.errors import MaxRetriesExceeded +from tests.conftest import start_docker_container + +LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443" + + +@pytest.fixture(scope="function") +def docker_containers(): + processes = ( + start_docker_container(9945, "9945"), + start_docker_container(9946, "9946"), + ) + try: + yield processes + + finally: + for process in processes: + subprocess.run(["docker", "kill", process.name]) + process.process.kill() + + +@pytest.fixture(scope="function") +def single_local_chain(): + process = start_docker_container(9945, "9945") + try: + yield process + finally: + subprocess.run(["docker", "kill", process.name]) + process.process.kill() + + +def test_retry_sync_substrate(single_local_chain): + time.sleep(10) + with RetrySyncSubstrate( + single_local_chain.uri, fallback_chains=[LATENT_LITE_ENTRYPOINT] + ) as substrate: + for i in range(10): + assert substrate.get_chain_head().startswith("0x") + if i == 8: + subprocess.run(["docker", "stop", single_local_chain.name]) + if i > 8: + assert substrate.chain_endpoint == LATENT_LITE_ENTRYPOINT + time.sleep(2) + + +def test_retry_sync_substrate_max_retries(docker_containers): + time.sleep(10) + with RetrySyncSubstrate( + docker_containers[0].uri, fallback_chains=[docker_containers[1].uri] + ) as substrate: + for i in range(5): + assert substrate.get_chain_head().startswith("0x") + if i == 2: + subprocess.run(["docker", "pause", docker_containers[0].name]) + if i == 3: + assert substrate.chain_endpoint == docker_containers[1].uri + if i == 4: + subprocess.run(["docker", "pause", docker_containers[1].name]) + with pytest.raises(MaxRetriesExceeded): + substrate.get_chain_head().startswith("0x") + time.sleep(2) + + +def test_retry_sync_substrate_offline(): + with pytest.raises(ConnectionError): + RetrySyncSubstrate( + "ws://127.0.0.1:9945", fallback_chains=["ws://127.0.0.1:9946"] + )
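For reference, the new `protocols.Keypair` protocol, together with the `inspect.isawaitable(signature)` check added to `create_signed_extrinsic`, is what enables "Support async key implementations" (PR #94): any duck-typed object exposing the protocol's attributes can stand in for a `bittensor_wallet` keypair, including one whose `sign` is a coroutine function. A minimal sketch of a conforming signer — the `RemoteSignerKeypair` class below is hypothetical, not part of this changeset:

```python
from async_substrate_interface.protocols import Keypair


class RemoteSignerKeypair:
    """Hypothetical keypair whose signing is delegated to a remote service."""

    crypto_type = 1  # SR25519 (see the KeypairType reference in protocols.py)
    ss58_format = 42

    def __init__(self, public_key: bytes, ss58_address: str):
        self.public_key = public_key
        self.ss58_address = ss58_address

    async def sign(self, data) -> bytes:
        # Forward `data` to an HSM or remote signing service here;
        # `create_signed_extrinsic` awaits the returned coroutine via its
        # `inspect.isawaitable` check.
        raise NotImplementedError


# The protocol is @runtime_checkable, so conformance can be asserted at runtime:
assert isinstance(
    RemoteSignerKeypair(
        b"\x00" * 32, "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"
    ),
    Keypair,
)
```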