diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index 8df2341..f10b582 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -734,15 +734,14 @@ async def initialize(self):
         """
         Initialize the connection to the chain.
         """
-        async with self._lock:
-            self._initializing = True
-            if not self.initialized:
-                if not self._chain:
-                    chain = await self.rpc_request("system_chain", [])
-                    self._chain = chain.get("result")
-                await self.init_runtime()
-                self.initialized = True
-            self._initializing = False
+        self._initializing = True
+        if not self.initialized:
+            if not self._chain:
+                chain = await self.rpc_request("system_chain", [])
+                self._chain = chain.get("result")
+            await self.init_runtime()
+            self.initialized = True
+        self._initializing = False
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         pass
diff --git a/async_substrate_interface/substrate_addons.py b/async_substrate_interface/substrate_addons.py
new file mode 100644
index 0000000..5edb26a
--- /dev/null
+++ b/async_substrate_interface/substrate_addons.py
@@ -0,0 +1,335 @@
+"""
+A number of "plugins" for SubstrateInterface (and AsyncSubstrateInterface). At initial creation, it contains only
+Retry (sync and async versions).
+"""
+
+import asyncio
+import logging
+import socket
+from functools import partial
+from itertools import cycle
+from typing import Optional
+
+from websockets.exceptions import ConnectionClosed
+
+from async_substrate_interface.async_substrate import AsyncSubstrateInterface, Websocket
+from async_substrate_interface.errors import MaxRetriesExceeded
+from async_substrate_interface.sync_substrate import SubstrateInterface
+
+logger = logging.getLogger("async_substrate_interface")
+
+
+RETRY_METHODS = [
+    "_get_block_handler",
+    "close",
+    "compose_call",
+    "create_scale_object",
+    "create_signed_extrinsic",
+    "create_storage_key",
+    "decode_scale",
+    "encode_scale",
+    "generate_signature_payload",
+    "get_account_next_index",
+    "get_account_nonce",
+    "get_block",
+    "get_block_hash",
+    "get_block_header",
+    "get_block_metadata",
+    "get_block_number",
+    "get_block_runtime_info",
+    "get_block_runtime_version_for",
+    "get_chain_finalised_head",
+    "get_chain_head",
+    "get_constant",
+    "get_events",
+    "get_extrinsics",
+    "get_metadata_call_function",
+    "get_metadata_constant",
+    "get_metadata_error",
+    "get_metadata_errors",
+    "get_metadata_module",
+    "get_metadata_modules",
+    "get_metadata_runtime_call_function",
+    "get_metadata_runtime_call_functions",
+    "get_metadata_storage_function",
+    "get_metadata_storage_functions",
+    "get_parent_block_hash",
+    "get_payment_info",
+    "get_storage_item",
+    "get_type_definition",
+    "get_type_registry",
+    "init_runtime",
+    "initialize",
+    "query",
+    "query_map",
+    "query_multi",
+    "query_multiple",
+    "retrieve_extrinsic_by_identifier",
+    "rpc_request",
+    "runtime_call",
+    "submit_extrinsic",
+    "subscribe_block_headers",
+    "supports_rpc_method",
+]
+
+RETRY_PROPS = ["properties", "version", "token_decimals", "token_symbol", "name"]
+
+
+class RetrySyncSubstrate(SubstrateInterface):
+    """
+    A subclass of SubstrateInterface that allows for handling chain failures by using backup chains. If a sustained
+    network failure is encountered on a chain endpoint, the object will initialize a new connection on the next chain in
+    the `fallback_chains` list. If the `retry_forever` flag is set, upon reaching the last chain in `fallback_chains`,
+    the connection will attempt to iterate over the list (starting with `url`) again.
+
+    E.g.
+    ```
+    substrate = RetrySyncSubstrate(
+        "wss://entrypoint-finney.opentensor.ai:443",
+        fallback_chains=["ws://127.0.0.1:9946"]
+    )
+    ```
+    In this case, if there is a failure on entrypoint-finney, the connection will next attempt to hit localhost. If this
+    also fails, a `MaxRetriesExceeded` exception will be raised.
+
+    ```
+    substrate = RetrySyncSubstrate(
+        "wss://entrypoint-finney.opentensor.ai:443",
+        fallback_chains=["ws://127.0.0.1:9946"],
+        retry_forever=True
+    )
+    ```
+    In this case, rather than a MaxRetriesExceeded exception being raised upon failure of the second chain (localhost),
+    the object will again begin to initialize a new connection on entrypoint-finney, and then localhost, and so on and
+    so forth.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        use_remote_preset: bool = False,
+        fallback_chains: Optional[list[str]] = None,
+        retry_forever: bool = False,
+        ss58_format: Optional[int] = None,
+        type_registry: Optional[dict] = None,
+        type_registry_preset: Optional[str] = None,
+        chain_name: str = "",
+        max_retries: int = 5,
+        retry_timeout: float = 60.0,
+        _mock: bool = False,
+    ):
+        fallback_chains = fallback_chains or []
+        self.fallback_chains = (
+            iter(fallback_chains)
+            if not retry_forever
+            else cycle(fallback_chains + [url])
+        )
+        self.use_remote_preset = use_remote_preset
+        self.chain_name = chain_name
+        self._mock = _mock
+        self.retry_timeout = retry_timeout
+        self.max_retries = max_retries
+        self.chain_endpoint = url
+        self.url = url
+        initialized = False
+        for chain_url in [url] + fallback_chains:
+            try:
+                self.chain_endpoint = chain_url
+                self.url = chain_url
+                super().__init__(
+                    url=chain_url,
+                    ss58_format=ss58_format,
+                    type_registry=type_registry,
+                    use_remote_preset=use_remote_preset,
+                    type_registry_preset=type_registry_preset,
+                    chain_name=chain_name,
+                    _mock=_mock,
+                    retry_timeout=retry_timeout,
+                    max_retries=max_retries,
+                )
+                initialized = True
+                logger.info(f"Connected to {chain_url}")
+                break
+            except ConnectionError:
+                logger.warning(f"Unable to connect to {chain_url}")
+        if not initialized:
+            raise ConnectionError(
+                f"Unable to connect at any chains specified: {[url] + fallback_chains}"
+            )
+        # "connect" is only used by SubstrateInterface, not AsyncSubstrateInterface
+        retry_methods = ["connect"] + RETRY_METHODS
+        self._original_methods = {
+            method: getattr(self, method) for method in retry_methods
+        }
+        for method in retry_methods:
+            setattr(self, method, partial(self._retry, method))
+
+    def _retry(self, method, *args, **kwargs):
+        """Call the original `method`, retrying once on the next fallback chain after a connection failure."""
+        method_ = self._original_methods[method]
+        try:
+            return method_(*args, **kwargs)
+        except (
+            MaxRetriesExceeded,
+            ConnectionError,
+            EOFError,
+            ConnectionClosed,
+            TimeoutError,
+        ) as e:
+            try:
+                self._reinstantiate_substrate(e)
+                return method_(*args, **kwargs)
+            except StopIteration:
+                logger.error(
+                    f"Max retries exceeded with {self.url}. No more fallback chains."
+                )
+                raise MaxRetriesExceeded from e
+
+    def _reinstantiate_substrate(self, e: Optional[Exception] = None) -> None:
+        """Close the current websocket and reconnect to the next fallback chain (initializing unless mocked)."""
+        next_network = next(self.fallback_chains)
+        self.ws.close()
+        if isinstance(e, MaxRetriesExceeded):
+            logger.error(
+                f"Max retries exceeded with {self.url}. Retrying with {next_network}."
+            )
+        else:
+            logger.error(f"Connection error. Trying again with {next_network}")
+        self.url = next_network
+        self.chain_endpoint = next_network
+        self.initialized = False
+        self.ws = self.connect(init=True)
+        if not self._mock:
+            self.initialize()
+
+
+class RetryAsyncSubstrate(AsyncSubstrateInterface):
+    """
+    A subclass of AsyncSubstrateInterface that allows for handling chain failures by using backup chains. If a
+    sustained network failure is encountered on a chain endpoint, the object will initialize a new connection on
+    the next chain in the `fallback_chains` list. If the `retry_forever` flag is set, upon reaching the last chain
+    in `fallback_chains`, the connection will attempt to iterate over the list (starting with `url`) again.
+
+    E.g.
+    ```
+    substrate = RetryAsyncSubstrate(
+        "wss://entrypoint-finney.opentensor.ai:443",
+        fallback_chains=["ws://127.0.0.1:9946"]
+    )
+    ```
+    In this case, if there is a failure on entrypoint-finney, the connection will next attempt to hit localhost. If this
+    also fails, a `MaxRetriesExceeded` exception will be raised.
+
+    ```
+    substrate = RetryAsyncSubstrate(
+        "wss://entrypoint-finney.opentensor.ai:443",
+        fallback_chains=["ws://127.0.0.1:9946"],
+        retry_forever=True
+    )
+    ```
+    In this case, rather than a MaxRetriesExceeded exception being raised upon failure of the second chain (localhost),
+    the object will again begin to initialize a new connection on entrypoint-finney, and then localhost, and so on and
+    so forth.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        use_remote_preset: bool = False,
+        fallback_chains: Optional[list[str]] = None,
+        retry_forever: bool = False,
+        ss58_format: Optional[int] = None,
+        type_registry: Optional[dict] = None,
+        type_registry_preset: Optional[str] = None,
+        chain_name: str = "",
+        max_retries: int = 5,
+        retry_timeout: float = 60.0,
+        _mock: bool = False,
+    ):
+        fallback_chains = fallback_chains or []
+        self.fallback_chains = (
+            iter(fallback_chains)
+            if not retry_forever
+            else cycle(fallback_chains + [url])
+        )
+        self.use_remote_preset = use_remote_preset
+        self.chain_name = chain_name
+        self._mock = _mock
+        self.retry_timeout = retry_timeout
+        self.max_retries = max_retries
+        super().__init__(
+            url=url,
+            ss58_format=ss58_format,
+            type_registry=type_registry,
+            use_remote_preset=use_remote_preset,
+            type_registry_preset=type_registry_preset,
+            chain_name=chain_name,
+            _mock=_mock,
+            retry_timeout=retry_timeout,
+            max_retries=max_retries,
+        )
+        self._original_methods = {
+            method: getattr(self, method) for method in RETRY_METHODS
+        }
+        for method in RETRY_METHODS:
+            setattr(self, method, partial(self._retry, method))
+
+    async def _reinstantiate_substrate(self, e: Optional[Exception] = None) -> None:
+        """Shut down the current websocket and re-initialize against the next fallback chain."""
+        try:
+            next_network = next(self.fallback_chains)
+        except StopIteration:
+            # A StopIteration escaping a coroutine is re-raised as RuntimeError, so
+            # exhaustion of the fallback chains is signalled with StopAsyncIteration.
+            raise StopAsyncIteration from None
+        if isinstance(e, MaxRetriesExceeded):
+            logger.error(
+                f"Max retries exceeded with {self.url}. Retrying with {next_network}."
+            )
+        else:
+            logger.error(f"Connection error. Trying again with {next_network}")
+        try:
+            await self.ws.shutdown()
+        except AttributeError:
+            pass
+        if self._forgettable_task is not None:
+            self._forgettable_task: asyncio.Task
+            self._forgettable_task.cancel()
+            try:
+                await self._forgettable_task
+            except asyncio.CancelledError:
+                pass
+        self.chain_endpoint = next_network
+        self.url = next_network
+        self.ws = Websocket(
+            next_network,
+            options={
+                "max_size": self.ws_max_size,
+                "write_limit": 2**16,
+            },
+        )
+        # `initialize()` checks `self.initialized`, so that is the flag to reset here.
+        self.initialized = False
+        self._initializing = False
+        await self.initialize()
+
+    async def _retry(self, method, *args, **kwargs):
+        """Await the original `method`, retrying once on the next fallback chain after a connection failure."""
+        method_ = self._original_methods[method]
+        try:
+            return await method_(*args, **kwargs)
+        except (
+            MaxRetriesExceeded,
+            ConnectionError,
+            ConnectionClosed,
+            EOFError,
+            socket.gaierror,
+        ) as e:
+            try:
+                await self._reinstantiate_substrate(e)
+                return await method_(*args, **kwargs)
+            except StopAsyncIteration:
+                logger.error(
+                    f"Max retries exceeded with {self.url}. No more fallback chains."
+                )
+                raise MaxRetriesExceeded from e
diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py
index 21664d4..221a04a 100644
--- a/async_substrate_interface/sync_substrate.py
+++ b/async_substrate_interface/sync_substrate.py
@@ -1,5 +1,6 @@
 import functools
 import logging
+import socket
 from hashlib import blake2b
 from typing import Optional, Union, Callable, Any
 
@@ -511,7 +512,6 @@ def __init__(
             "strict_scale_decode": True,
         }
         self.initialized = False
-        self._forgettable_task = None
         self.ss58_format = ss58_format
         self.type_registry = type_registry
         self.type_registry_preset = type_registry_preset
@@ -587,13 +587,19 @@ def name(self):
 
     def connect(self, init=False):
         if init is True:
-            return connect(self.chain_endpoint, max_size=self.ws_max_size)
+            try:
+                return connect(self.chain_endpoint, max_size=self.ws_max_size)
+            except (ConnectionError, socket.gaierror) as e:
+                raise ConnectionError(e) from e
         else:
             if not self.ws.close_code:
                 return self.ws
             else:
-                self.ws = connect(self.chain_endpoint, max_size=self.ws_max_size)
-                return self.ws
+                try:
+                    self.ws = connect(self.chain_endpoint, max_size=self.ws_max_size)
+                    return self.ws
+                except (ConnectionError, socket.gaierror) as e:
+                    raise ConnectionError(e) from e
 
     def get_storage_item(
         self, module: str, storage_function: str, block_hash: str = None
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..0e312e3
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,27 @@
+import subprocess
+from collections import namedtuple
+
+CONTAINER_NAME_PREFIX = "test_local_chain_"
+LOCALNET_IMAGE_NAME = "ghcr.io/opentensor/subtensor-localnet:devnet-ready"
+
+Container = namedtuple("Container", ["process", "name", "uri"])
+
+
+def start_docker_container(exposed_port, name_salt: str):
+    """Start a localnet container with its port 9945 published on `exposed_port` and return its handle."""
+    container_name = f"{CONTAINER_NAME_PREFIX}{name_salt}"
+
+    # Command to start container
+    cmds = [
+        "docker",
+        "run",
+        "--rm",
+        "--name",
+        container_name,
+        "-p",
+        f"{exposed_port}:9945",
+        LOCALNET_IMAGE_NAME,
+    ]
+
+    proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    return Container(proc, container_name, f"ws://127.0.0.1:{exposed_port}")
diff --git a/tests/test_substrate_addons.py b/tests/test_substrate_addons.py
new file mode 100644
index 0000000..686a028
--- /dev/null
+++ b/tests/test_substrate_addons.py
@@ -0,0 +1,74 @@
+import threading
+import subprocess
+
+import pytest
+import time
+
+from async_substrate_interface.substrate_addons import RetrySyncSubstrate
+from async_substrate_interface.errors import MaxRetriesExceeded
+from tests.conftest import start_docker_container
+
+LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443"
+
+
+@pytest.fixture(scope="function")
+def docker_containers():
+    processes = (start_docker_container(9945, 9945), start_docker_container(9946, 9946))
+    try:
+        yield processes
+
+    finally:
+        for process in processes:
+            subprocess.run(["docker", "kill", process.name])
+            process.process.kill()
+
+
+@pytest.fixture(scope="function")
+def single_local_chain():
+    process = start_docker_container(9945, 9945)
+    try:
+        yield process
+    finally:
+        print("TRIGGERED KILL")
+        subprocess.run(["docker", "kill", process.name])
+        process.process.kill()
+
+
+def test_retry_sync_substrate(single_local_chain):
+    time.sleep(10)
+    with RetrySyncSubstrate(
+        single_local_chain.uri, fallback_chains=[LATENT_LITE_ENTRYPOINT]
+    ) as substrate:
+        for i in range(10):
+            assert substrate.get_chain_head().startswith("0x")
+            if i == 8:
+                subprocess.run(["docker", "stop", single_local_chain.name])
+            if i > 8:
+                assert substrate.chain_endpoint == LATENT_LITE_ENTRYPOINT
+            time.sleep(2)
+
+
+def test_retry_sync_substrate_max_retries(docker_containers):
+    time.sleep(10)
+    with RetrySyncSubstrate(
+        docker_containers[0].uri, fallback_chains=[docker_containers[1].uri]
+    ) as substrate:
+        for i in range(5):
+            print("EYE EQUALS", i)
+            assert substrate.get_chain_head().startswith("0x")
+            if i == 2:
+                subprocess.run(["docker", "pause", docker_containers[0].name])
+            if i == 3:
+                assert substrate.chain_endpoint == docker_containers[1].uri
+            if i == 4:
+                subprocess.run(["docker", "pause", docker_containers[1].name])
+                with pytest.raises(MaxRetriesExceeded):
+                    substrate.get_chain_head().startswith("0x")
+            time.sleep(2)
+
+
+def test_retry_sync_substrate_offline():
+    with pytest.raises(ConnectionError):
+        RetrySyncSubstrate(
+            "ws://127.0.0.1:9945", fallback_chains=["ws://127.0.0.1:9946"]
+        )