diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index 598a882..3b96e66 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -58,7 +58,10 @@
     get_next_id,
     rng as random,
 )
-from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher
+from async_substrate_interface.utils.cache import (
+    async_sql_lru_cache,
+    cached_fetcher,
+)
 from async_substrate_interface.utils.decoding import (
     _determine_if_old_runtime_call,
     _bt_decode_to_dict_or_list,
@@ -794,12 +797,6 @@ def __init__(
         self.registry_type_map = {}
         self.type_id_to_name = {}
         self._mock = _mock
-        self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash)
-        self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash)
-        self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info)
-        self._runtime_version_for_fetcher = CachedFetcher(
-            512, self._get_block_runtime_version_for
-        )
 
     async def __aenter__(self):
         if not self._mock:
@@ -1044,35 +1041,7 @@ async def init_runtime(
 
         if not runtime:
             self.last_block_hash = block_hash
-            runtime_block_hash = await self.get_parent_block_hash(block_hash)
-
-            runtime_info = await self.get_block_runtime_info(runtime_block_hash)
-
-            metadata, (metadata_v15, registry) = await asyncio.gather(
-                self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
-                self._load_registry_at_block(block_hash=runtime_block_hash),
-            )
-            if metadata is None:
-                # does this ever happen?
-                raise SubstrateRequestException(
-                    f"No metadata for block '{runtime_block_hash}'"
-                )
-            logger.debug(
-                f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
-            )
-
-            runtime = Runtime(
-                chain=self.chain,
-                runtime_config=self.runtime_config,
-                metadata=metadata,
-                type_registry=self.type_registry,
-                metadata_v15=metadata_v15,
-                runtime_info=runtime_info,
-                registry=registry,
-            )
-            self.runtime_cache.add_item(
-                runtime_version=runtime_version, runtime=runtime
-            )
+            runtime = await self.get_runtime_for_version(runtime_version, block_hash)
 
         self.load_runtime(runtime)
 
@@ -1086,6 +1055,51 @@ async def init_runtime(
             self.ss58_format = ss58_prefix_constant
         return runtime
 
+    @cached_fetcher(max_size=16, cache_key_index=0)
+    async def get_runtime_for_version(
+        self, runtime_version: int, block_hash: Optional[str] = None
+    ) -> Runtime:
+        """
+        Retrieves the `Runtime` for a given runtime version at a given block hash.
+        Args:
+            runtime_version: version of the runtime (from `get_block_runtime_version_for`)
+            block_hash: hash of the block to query
+
+        Returns:
+            Runtime object for the given runtime version
+        """
+        return await self._get_runtime_for_version(runtime_version, block_hash)
+
+    async def _get_runtime_for_version(
+        self, runtime_version: int, block_hash: Optional[str] = None
+    ) -> Runtime:
+        runtime_block_hash = await self.get_parent_block_hash(block_hash)
+        runtime_info, metadata, (metadata_v15, registry) = await asyncio.gather(
+            self.get_block_runtime_info(runtime_block_hash),
+            self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
+            self._load_registry_at_block(block_hash=runtime_block_hash),
+        )
+        if metadata is None:
+            # does this ever happen?
+            raise SubstrateRequestException(
+                f"No metadata for block '{runtime_block_hash}'"
+            )
+        logger.debug(
+            f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
+        )
+
+        runtime = Runtime(
+            chain=self.chain,
+            runtime_config=self.runtime_config,
+            metadata=metadata,
+            type_registry=self.type_registry,
+            metadata_v15=metadata_v15,
+            runtime_info=runtime_info,
+            registry=registry,
+        )
+        self.runtime_cache.add_item(runtime_version=runtime_version, runtime=runtime)
+        return runtime
+
     async def create_storage_key(
         self,
         pallet: str,
@@ -1921,10 +1935,19 @@ async def get_metadata(self, block_hash=None) -> MetadataV15:
 
         return runtime.metadata_v15
 
-    async def get_parent_block_hash(self, block_hash):
-        return await self._parent_hash_fetcher.execute(block_hash)
+    @cached_fetcher(max_size=512)
+    async def get_parent_block_hash(self, block_hash) -> str:
+        """
+        Retrieves the block hash of the parent of the given block hash
+        Args:
+            block_hash: hash of the block to query
+
+        Returns:
+            Hash of the parent block, or the original block hash (if it has no parent)
+        """
+        return await self._get_parent_block_hash(block_hash)
 
-    async def _get_parent_block_hash(self, block_hash):
+    async def _get_parent_block_hash(self, block_hash) -> str:
         block_header = await self.rpc_request("chain_getHeader", [block_hash])
 
         if block_header["result"] is None:
@@ -1967,25 +1990,27 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any:
                 "Unknown error occurred during retrieval of events"
             )
 
+    @cached_fetcher(max_size=16)
     async def get_block_runtime_info(self, block_hash: str) -> dict:
-        return await self._runtime_info_fetcher.execute(block_hash)
+        """
+        Retrieve the runtime info of a given block_hash
+        """
+        return await self._get_block_runtime_info(block_hash)
 
     get_block_runtime_version = get_block_runtime_info
 
     async def _get_block_runtime_info(self, block_hash: str) -> dict:
-        """
-        Retrieve the runtime info of given block_hash
-        """
         response = await self.rpc_request("state_getRuntimeVersion", [block_hash])
         return response.get("result")
 
+    @cached_fetcher(max_size=512)
     async def get_block_runtime_version_for(self, block_hash: str):
-        return await self._runtime_version_for_fetcher.execute(block_hash)
-
-    async def _get_block_runtime_version_for(self, block_hash: str):
         """
         Retrieve the runtime version of the parent of a given block_hash
         """
+        return await self._get_block_runtime_version_for(block_hash)
+
+    async def _get_block_runtime_version_for(self, block_hash: str):
         parent_block_hash = await self.get_parent_block_hash(block_hash)
         runtime_info = await self.get_block_runtime_info(parent_block_hash)
         if runtime_info is None:
@@ -2296,8 +2321,17 @@ async def rpc_request(
         else:
             raise SubstrateRequestException(result[payload_id][0])
 
+    @cached_fetcher(max_size=512)
     async def get_block_hash(self, block_id: int) -> str:
-        return await self._block_hash_fetcher.execute(block_id)
+        """
+        Retrieves the hash of the specified block number
+        Args:
+            block_id: block number
+
+        Returns:
+            Hash of the block
+        """
+        return await self._get_block_hash(block_id)
 
     async def _get_block_hash(self, block_id: int) -> str:
         return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]
diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py
index fa4be3c..23bbf9f 100644
--- a/async_substrate_interface/utils/cache.py
+++ b/async_substrate_interface/utils/cache.py
@@ -1,14 +1,13 @@
 import asyncio
+import inspect
 from collections import OrderedDict
 import functools
+import logging
 import os
 import pickle
 import sqlite3
 from pathlib import Path
-from typing import Callable, Any
-
-import asyncstdlib as a
-
+from typing import Callable, Any, Awaitable, Hashable, Optional
 
 USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
 CACHE_LOCATION = (
@@ -19,6 +18,8 @@
     else ":memory:"
 )
 
+logger = logging.getLogger("async_substrate_interface")
+
 
 def _ensure_dir():
     path = Path(CACHE_LOCATION).parent
@@ -70,7 +71,7 @@ def _retrieve_from_cache(c, table_name, key, chain):
         if result is not None:
            return pickle.loads(result[0])
     except (pickle.PickleError, sqlite3.Error) as e:
-        print(f"Cache error: {str(e)}")
+        logger.exception("Cache error", exc_info=e)
         pass
 
 
@@ -82,7 +83,7 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
         )
         conn.commit()
     except (pickle.PickleError, sqlite3.Error) as e:
-        print(f"Cache error: {str(e)}")
+        logger.exception("Cache error", exc_info=e)
         pass
 
 
@@ -128,7 +129,7 @@ def inner(self, *args, **kwargs):
 def async_sql_lru_cache(maxsize=None):
     def decorator(func):
-        @a.lru_cache(maxsize=maxsize)
+        @cached_fetcher(max_size=maxsize)
         async def inner(self, *args, **kwargs):
             c, conn, table_name, key, result, chain, local_chain = (
                 _shared_inner_fn_logic(func, self, args, kwargs)
             )
@@ -147,6 +148,10 @@ async def inner(self, *args, **kwargs):
 
 
 class LRUCache:
+    """
+    Basic Least-Recently-Used Cache, with simple methods `set` and `get`
+    """
+
     def __init__(self, max_size: int):
         self.max_size = max_size
         self.cache = OrderedDict()
@@ -167,31 +172,121 @@ def get(self, key):
 
 
 class CachedFetcher:
-    def __init__(self, max_size: int, method: Callable):
-        self._inflight: dict[int, asyncio.Future] = {}
+    """
+    Async caching class that provides standard async LRU caching, but also allows concurrent
+    asyncio calls (with the same args) to share the result of a single call.
+
+    This should only be used for asyncio calls where the result is immutable.
+
+    Concept and usage:
+    ```
+    async def fetch(self, block_hash: str) -> str:
+        return await some_resource(block_hash)
+
+    a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
+    ```
+
+    Here, you are making three requests, but you really only need to make two I/O requests
+    (one for "a", one for "b"), and while you wouldn't typically make a request like this directly, it's very
+    common when using this library to inadvertently make these requests by gathering multiple resources that
+    depend on calls like this under the hood.
+
+    By using
+
+    ```
+    @cached_fetcher(max_size=512)
+    async def fetch(self, block_hash: str) -> str:
+        return await some_resource(block_hash)
+
+    a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
+    ```
+
+    you are only making two I/O calls, and a2 will simply use the result of a1 when it lands.
+    """
+
+    def __init__(
+        self,
+        max_size: int,
+        method: Callable[..., Awaitable[Any]],
+        cache_key_index: Optional[int] = 0,
+    ):
+        """
+        Args:
+            max_size: max size of the cache (in items)
+            method: the function to cache
+            cache_key_index: if the method takes multiple args, this is the index of the arg to use as the cache key
+                (default is the first arg). Setting this to `None` uses all args as the cache key.
+ """ + self._inflight: dict[Hashable, asyncio.Future] = {} self._method = method self._cache = LRUCache(max_size=max_size) + self._cache_key_index = cache_key_index - async def execute(self, single_arg: Any) -> str: - if item := self._cache.get(single_arg): + def make_cache_key(self, args: tuple, kwargs: dict) -> Hashable: + bound = inspect.signature(self._method).bind(*args, **kwargs) + bound.apply_defaults() + + if self._cache_key_index is not None: + key_name = list(bound.arguments)[self._cache_key_index] + return bound.arguments[key_name] + + return (tuple(bound.arguments.items()),) + + async def __call__(self, *args: Any, **kwargs: Any) -> Any: + key = self.make_cache_key(args, kwargs) + + if item := self._cache.get(key): return item - if single_arg in self._inflight: - result = await self._inflight[single_arg] - return result + if key in self._inflight: + return await self._inflight[key] loop = asyncio.get_running_loop() future = loop.create_future() - self._inflight[single_arg] = future + self._inflight[key] = future try: - result = await self._method(single_arg) - self._cache.set(single_arg, result) + result = await self._method(*args, **kwargs) + self._cache.set(key, result) future.set_result(result) return result except Exception as e: - # Propagate errors future.set_exception(e) raise finally: - self._inflight.pop(single_arg, None) + self._inflight.pop(key, None) + + +class _CachedFetcherMethod: + """ + Helper class for using CachedFetcher with method caches (rather than functions) + """ + + def __init__(self, method, max_size: int, cache_key_index: int): + self.method = method + self.max_size = max_size + self.cache_key_index = cache_key_index + self._instances = {} + + def __get__(self, instance, owner): + if instance is None: + return self + + # Cache per-instance + if instance not in self._instances: + bound_method = self.method.__get__(instance, owner) + self._instances[instance] = CachedFetcher( + max_size=self.max_size, + method=bound_method, + cache_key_index=self.cache_key_index, + ) + return self._instances[instance] + + +def cached_fetcher(max_size: int, cache_key_index: int = 0): + """Wrapper for CachedFetcher. 
+
+    def wrapper(method):
+        return _CachedFetcherMethod(method, max_size, cache_key_index)
+
+    return wrapper
diff --git a/pyproject.toml b/pyproject.toml
index 389ea1a..c3b6ab1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,6 @@
 keywords = ["substrate", "development", "bittensor"]
 dependencies = [
     "wheel",
-    "asyncstdlib~=3.13.0",
     "bt-decode==v0.6.0",
     "scalecodec~=1.2.11",
     "websockets>=14.1",
diff --git a/tests/unit_tests/test_cache.py b/tests/unit_tests/test_cache.py
index 7844202..dddb2e8 100644
--- a/tests/unit_tests/test_cache.py
+++ b/tests/unit_tests/test_cache.py
@@ -13,18 +13,18 @@ async def test_cached_fetcher_fetches_and_caches():
     fetcher = CachedFetcher(max_size=2, method=mock_method)
 
     # First call should trigger the method
-    result1 = await fetcher.execute("key1")
+    result1 = await fetcher("key1")
     assert result1 == "result_key1"
     mock_method.assert_awaited_once_with("key1")
 
     # Second call with the same key should use the cache
-    result2 = await fetcher.execute("key1")
+    result2 = await fetcher("key1")
     assert result2 == "result_key1"
     # Ensure the method was NOT called again
     assert mock_method.await_count == 1
 
     # Third call with a new key triggers a method call
-    result3 = await fetcher.execute("key2")
+    result3 = await fetcher("key2")
     assert result3 == "result_key2"
     assert mock_method.await_count == 2
 
@@ -42,11 +42,11 @@ async def slow_method(x):
     fetcher = CachedFetcher(max_size=2, method=slow_method)
 
     # Start first request
-    task1 = asyncio.create_task(fetcher.execute("key1"))
+    task1 = asyncio.create_task(fetcher("key1"))
     await asyncio.sleep(0.1)  # Let the task start and be inflight
 
     # Second request for the same key while the first is in-flight
-    task2 = asyncio.create_task(fetcher.execute("key1"))
+    task2 = asyncio.create_task(fetcher("key1"))
     await asyncio.sleep(0.1)
 
     # Release the inflight request
@@ -65,7 +65,7 @@ async def error_method(x):
     fetcher = CachedFetcher(max_size=2, method=error_method)
 
     with pytest.raises(ValueError, match="Boom!"):
-        await fetcher.execute("key1")
+        await fetcher("key1")
 
 
 @pytest.mark.asyncio
@@ -75,12 +75,12 @@ async def test_cached_fetcher_eviction():
     fetcher = CachedFetcher(max_size=2, method=mock_method)
 
     # Fill cache
-    await fetcher.execute("key1")
-    await fetcher.execute("key2")
+    await fetcher("key1")
+    await fetcher("key2")
     assert list(fetcher._cache.cache.keys()) == list(fetcher._cache.cache.keys())
 
     # Insert a new key to trigger eviction
-    await fetcher.execute("key3")
+    await fetcher("key3")
     # key1 should be evicted
     assert "key1" not in fetcher._cache.cache
     assert "key2" in fetcher._cache.cache
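
Reviewer note (not part of the diff): a minimal, hypothetical sketch of how the `cached_fetcher` decorator introduced above is used. `ExampleClient` and `fetch_block_hash` are illustrative stand-ins, not library code; only the `cached_fetcher` import is real. In the PR itself the decorator is applied to methods of the async interface class in `async_substrate.py`, such as `get_block_hash` and `get_parent_block_hash`.

```python
# Hypothetical illustration of the cached_fetcher decorator added in this PR.
# ExampleClient / fetch_block_hash are stand-ins; only the import is real.
import asyncio

from async_substrate_interface.utils.cache import cached_fetcher


class ExampleClient:
    @cached_fetcher(max_size=512)
    async def fetch_block_hash(self, block_id: int) -> str:
        # Stand-in for an expensive RPC round-trip.
        await asyncio.sleep(0.1)
        return f"0xhash-{block_id}"


async def main():
    client = ExampleClient()
    # Three awaits, two underlying calls: the concurrent requests for block 1
    # share one in-flight future, and later hits are served from the LRU cache.
    h1, h2, h3 = await asyncio.gather(
        client.fetch_block_hash(1),
        client.fetch_block_hash(1),
        client.fetch_block_hash(2),
    )
    assert h1 == h2 and h2 != h3


asyncio.run(main())
```

Because `_CachedFetcherMethod.__get__` builds one `CachedFetcher` per instance, separate client objects do not share a cache.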