Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 81 additions & 47 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
get_next_id,
rng as random,
)
from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher
from async_substrate_interface.utils.cache import (
async_sql_lru_cache,
cached_fetcher,
)
from async_substrate_interface.utils.decoding import (
_determine_if_old_runtime_call,
_bt_decode_to_dict_or_list,
Expand Down Expand Up @@ -794,12 +797,6 @@ def __init__(
self.registry_type_map = {}
self.type_id_to_name = {}
self._mock = _mock
self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash)
self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash)
self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info)
self._runtime_version_for_fetcher = CachedFetcher(
512, self._get_block_runtime_version_for
)

async def __aenter__(self):
if not self._mock:
Expand Down Expand Up @@ -1044,35 +1041,7 @@ async def init_runtime(
if not runtime:
self.last_block_hash = block_hash

runtime_block_hash = await self.get_parent_block_hash(block_hash)

runtime_info = await self.get_block_runtime_info(runtime_block_hash)

metadata, (metadata_v15, registry) = await asyncio.gather(
self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
self._load_registry_at_block(block_hash=runtime_block_hash),
)
if metadata is None:
# does this ever happen?
raise SubstrateRequestException(
f"No metadata for block '{runtime_block_hash}'"
)
logger.debug(
f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
)

runtime = Runtime(
chain=self.chain,
runtime_config=self.runtime_config,
metadata=metadata,
type_registry=self.type_registry,
metadata_v15=metadata_v15,
runtime_info=runtime_info,
registry=registry,
)
self.runtime_cache.add_item(
runtime_version=runtime_version, runtime=runtime
)
runtime = await self.get_runtime_for_version(runtime_version, block_hash)

self.load_runtime(runtime)

Expand All @@ -1086,6 +1055,51 @@ async def init_runtime(
self.ss58_format = ss58_prefix_constant
return runtime

@cached_fetcher(max_size=16, cache_key_index=0)
async def get_runtime_for_version(
self, runtime_version: int, block_hash: Optional[str] = None
) -> Runtime:
"""
Retrieves the `Runtime` for a given runtime version at a given block hash.
Args:
runtime_version: version of the runtime (from `get_block_runtime_version_for`)
block_hash: hash of the block to query

Returns:
Runtime object for the given runtime version
"""
return await self._get_runtime_for_version(runtime_version, block_hash)

async def _get_runtime_for_version(
self, runtime_version: int, block_hash: Optional[str] = None
) -> Runtime:
runtime_block_hash = await self.get_parent_block_hash(block_hash)
runtime_info, metadata, (metadata_v15, registry) = await asyncio.gather(
self.get_block_runtime_info(runtime_block_hash),
self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
self._load_registry_at_block(block_hash=runtime_block_hash),
)
if metadata is None:
# does this ever happen?
raise SubstrateRequestException(
f"No metadata for block '{runtime_block_hash}'"
)
logger.debug(
f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
)

runtime = Runtime(
chain=self.chain,
runtime_config=self.runtime_config,
metadata=metadata,
type_registry=self.type_registry,
metadata_v15=metadata_v15,
runtime_info=runtime_info,
registry=registry,
)
self.runtime_cache.add_item(runtime_version=runtime_version, runtime=runtime)
return runtime

async def create_storage_key(
self,
pallet: str,
Expand Down Expand Up @@ -1921,10 +1935,19 @@ async def get_metadata(self, block_hash=None) -> MetadataV15:

return runtime.metadata_v15

async def get_parent_block_hash(self, block_hash):
return await self._parent_hash_fetcher.execute(block_hash)
@cached_fetcher(max_size=512)
async def get_parent_block_hash(self, block_hash) -> str:
"""
Retrieves the block hash of the parent of the given block hash
Args:
block_hash: hash of the block to query

Returns:
Hash of the parent block hash, or the original block hash (if it has not parent)
"""
return await self._get_parent_block_hash(block_hash)

async def _get_parent_block_hash(self, block_hash):
async def _get_parent_block_hash(self, block_hash) -> str:
block_header = await self.rpc_request("chain_getHeader", [block_hash])

if block_header["result"] is None:
Expand Down Expand Up @@ -1967,25 +1990,27 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any:
"Unknown error occurred during retrieval of events"
)

@cached_fetcher(max_size=16)
async def get_block_runtime_info(self, block_hash: str) -> dict:
return await self._runtime_info_fetcher.execute(block_hash)
"""
Retrieve the runtime info of given block_hash
"""
return await self._get_block_runtime_info(block_hash)

get_block_runtime_version = get_block_runtime_info

async def _get_block_runtime_info(self, block_hash: str) -> dict:
"""
Retrieve the runtime info of given block_hash
"""
response = await self.rpc_request("state_getRuntimeVersion", [block_hash])
return response.get("result")

@cached_fetcher(max_size=512)
async def get_block_runtime_version_for(self, block_hash: str):
return await self._runtime_version_for_fetcher.execute(block_hash)

async def _get_block_runtime_version_for(self, block_hash: str):
"""
Retrieve the runtime version of the parent of a given block_hash
"""
return await self._get_block_runtime_version_for(block_hash)

async def _get_block_runtime_version_for(self, block_hash: str):
parent_block_hash = await self.get_parent_block_hash(block_hash)
runtime_info = await self.get_block_runtime_info(parent_block_hash)
if runtime_info is None:
Expand Down Expand Up @@ -2296,8 +2321,17 @@ async def rpc_request(
else:
raise SubstrateRequestException(result[payload_id][0])

@cached_fetcher(max_size=512)
async def get_block_hash(self, block_id: int) -> str:
return await self._block_hash_fetcher.execute(block_id)
"""
Retrieves the hash of the specified block number
Args:
block_id: block number

Returns:
Hash of the block
"""
return await self._get_block_hash(block_id)

async def _get_block_hash(self, block_id: int) -> str:
return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]
Expand Down
133 changes: 114 additions & 19 deletions async_substrate_interface/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import asyncio
import inspect
from collections import OrderedDict
import functools
import logging
import os
import pickle
import sqlite3
from pathlib import Path
from typing import Callable, Any

import asyncstdlib as a

from typing import Callable, Any, Awaitable, Hashable, Optional

USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
CACHE_LOCATION = (
Expand All @@ -19,6 +18,8 @@
else ":memory:"
)

logger = logging.getLogger("async_substrate_interface")


def _ensure_dir():
path = Path(CACHE_LOCATION).parent
Expand Down Expand Up @@ -70,7 +71,7 @@ def _retrieve_from_cache(c, table_name, key, chain):
if result is not None:
return pickle.loads(result[0])
except (pickle.PickleError, sqlite3.Error) as e:
print(f"Cache error: {str(e)}")
logger.exception("Cache error", exc_info=e)
pass


Expand All @@ -82,7 +83,7 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
)
conn.commit()
except (pickle.PickleError, sqlite3.Error) as e:
print(f"Cache error: {str(e)}")
logger.exception("Cache error", exc_info=e)
pass


Expand Down Expand Up @@ -128,7 +129,7 @@ def inner(self, *args, **kwargs):

def async_sql_lru_cache(maxsize=None):
def decorator(func):
@a.lru_cache(maxsize=maxsize)
@cached_fetcher(max_size=maxsize)
async def inner(self, *args, **kwargs):
c, conn, table_name, key, result, chain, local_chain = (
_shared_inner_fn_logic(func, self, args, kwargs)
Expand All @@ -147,6 +148,10 @@ async def inner(self, *args, **kwargs):


class LRUCache:
"""
Basic Least-Recently-Used Cache, with simple methods `set` and `get`
"""

def __init__(self, max_size: int):
self.max_size = max_size
self.cache = OrderedDict()
Expand All @@ -167,31 +172,121 @@ def get(self, key):


class CachedFetcher:
def __init__(self, max_size: int, method: Callable):
self._inflight: dict[int, asyncio.Future] = {}
"""
Async caching class that allows the standard async LRU cache system, but also allows for concurrent
asyncio calls (with the same args) to use the same result of a single call.

This should only be used for asyncio calls where the result is immutable.

Concept and usage:
```
async def fetch(self, block_hash: str) -> str:
return await some_resource(block_hash)

a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
```

Here, you are making three requests, but you really only need to make two I/O requests
(one for "a", one for "b"), and while you wouldn't typically make a request like this directly, it's very
common in using this library to inadvertently make these requests y gathering multiple resources that depend
on the calls like this under the hood.

By using

```
@cached_fetcher(max_size=512)
async def fetch(self, block_hash: str) -> str:
return await some_resource(block_hash)

a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
```

You are only making two I/O calls, and a2 will simply use the result of a1 when it lands.
"""

def __init__(
self,
max_size: int,
method: Callable[..., Awaitable[Any]],
cache_key_index: Optional[int] = 0,
):
"""
Args:
max_size: max size of the cache (in items)
method: the function to cache
cache_key_index: if the method takes multiple args, this is the index of that cache key in the args list
(default is the first arg). By setting this to `None`, it will use all args as the cache key.
"""
self._inflight: dict[Hashable, asyncio.Future] = {}
self._method = method
self._cache = LRUCache(max_size=max_size)
self._cache_key_index = cache_key_index

async def execute(self, single_arg: Any) -> str:
if item := self._cache.get(single_arg):
def make_cache_key(self, args: tuple, kwargs: dict) -> Hashable:
bound = inspect.signature(self._method).bind(*args, **kwargs)
bound.apply_defaults()

if self._cache_key_index is not None:
key_name = list(bound.arguments)[self._cache_key_index]
return bound.arguments[key_name]

return (tuple(bound.arguments.items()),)

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
key = self.make_cache_key(args, kwargs)

if item := self._cache.get(key):
return item

if single_arg in self._inflight:
result = await self._inflight[single_arg]
return result
if key in self._inflight:
return await self._inflight[key]

loop = asyncio.get_running_loop()
future = loop.create_future()
self._inflight[single_arg] = future
self._inflight[key] = future

try:
result = await self._method(single_arg)
self._cache.set(single_arg, result)
result = await self._method(*args, **kwargs)
self._cache.set(key, result)
future.set_result(result)
return result
except Exception as e:
# Propagate errors
future.set_exception(e)
raise
finally:
self._inflight.pop(single_arg, None)
self._inflight.pop(key, None)


class _CachedFetcherMethod:
"""
Helper class for using CachedFetcher with method caches (rather than functions)
"""

def __init__(self, method, max_size: int, cache_key_index: int):
self.method = method
self.max_size = max_size
self.cache_key_index = cache_key_index
self._instances = {}

def __get__(self, instance, owner):
if instance is None:
return self

# Cache per-instance
if instance not in self._instances:
bound_method = self.method.__get__(instance, owner)
self._instances[instance] = CachedFetcher(
max_size=self.max_size,
method=bound_method,
cache_key_index=self.cache_key_index,
)
return self._instances[instance]


def cached_fetcher(max_size: int, cache_key_index: int = 0):
"""Wrapper for CachedFetcher. See example in CachedFetcher docstring."""

def wrapper(method):
return _CachedFetcherMethod(method, max_size, cache_key_index)

return wrapper
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ keywords = ["substrate", "development", "bittensor"]

dependencies = [
"wheel",
"asyncstdlib~=3.13.0",
"bt-decode==v0.6.0",
"scalecodec~=1.2.11",
"websockets>=14.1",
Expand Down
Loading
Loading