diff --git a/docs/guides/trinity/architecture.rst b/docs/guides/trinity/architecture.rst index 3821a688a6..5b6f773b40 100644 --- a/docs/guides/trinity/architecture.rst +++ b/docs/guides/trinity/architecture.rst @@ -109,8 +109,8 @@ that all database reads and writes are done by a single process. Networking Process ------------------ -The networking process is what kicks of the peer to peer communication, starts the syncing -process and also serves the JSON-RPC API. It does so by running an instance of +The networking process is what kicks of the peer to peer communication and starts the syncing +process. It does so by running an instance of :func:`~trinity.nodes.base.Node` in an event loop. Notice that the instance of :func:`~trinity.nodes.base.Node` has access to the APIs that the @@ -119,3 +119,14 @@ connections to other peers, starts of the syncing process but will call APIs tha the database processes when it comes to actual importing of blocks or reading and writing of other things from the database. +The networking process also host an instance of the +:class:`~trinity.extensibility.plugin_manager.PluginManager` to run plugins that need to deeply +integrate with the networking process (Further reading: +:doc:`Writing Plugins`). + +Plugin Processes +---------------- + +Apart from running these three core processes, there may be additional processes for plugins that +run in isolated processes. Isolated plugins are explained in depth in the +:doc:`Writing Plugins` guide. \ No newline at end of file diff --git a/docs/guides/trinity/index.rst b/docs/guides/trinity/index.rst index ad63941e19..a19412f8c8 100644 --- a/docs/guides/trinity/index.rst +++ b/docs/guides/trinity/index.rst @@ -9,4 +9,5 @@ This section aims to provide hands-on guides to demonstrate how to use Trinity. :caption: Guides quickstart - architecture \ No newline at end of file + architecture + writing_plugins \ No newline at end of file diff --git a/docs/guides/trinity/writing_plugins.rst b/docs/guides/trinity/writing_plugins.rst new file mode 100644 index 0000000000..6777e71b6c --- /dev/null +++ b/docs/guides/trinity/writing_plugins.rst @@ -0,0 +1,263 @@ +Writing Plugins +=============== + +Trinity aims to be a highly flexible Ethereum node to support lots of different use cases +beyond just participating in the regular networking traffic. + +To support this goal, Trinity allows developers to create plugins that hook into the system to +extend its functionality. In fact, Trinity dogfoods its Plugin API in the sense that several +built-in features are written as plugins that just happen to be shipped among the rest of the core +modules. For instance, the JSON-RPC API, the Transaction Pool as well as the ``trinity attach`` +command that provides an interactive REPL with `Web3` integration are all built as plugins. + +Trinity tries to follow the practice: If something can be written as a plugin, it should be written +as a plugin. + + +What can plugins do? +~~~~~~~~~~~~~~~~~~~~ + +Plugin support in Trinity is still very new and the API hasn't stabilized yet. That said, plugins +are already pretty powerful and are only becoming more so as the APIs of the underlying services +improve over time. + +Here's a list of functionality that is currently provided by plugins: + +- JSON-RPC API +- Transaction Pool +- EthStats Reporting +- Interactive REPL with Web3 integration +- Crash Recovery Command + + +Understanding the different plugin categories +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are currently three different types of plugins that we'll all cover in this guide. + +- Plugins that overtake and redefine the entire ``trinity`` command +- Plugins that spawn their own new isolated process +- Plugins that run in the shared `networking` process + + + +Plugins that redefine the Trinity process +----------------------------------------- + +This is the simplest category of plugins as it doesn't really *hook* into the Trinity process but +hijacks it entirely instead. We may be left wonderering: Why would one want to do that? + +The only reason to write such a plugin is to execute some code that we want to group under the +``trinity`` command. A great example for such a plugin is the ``trinity attach`` command that gives +us a REPL attached to a running Trinity instance. This plugin could have easily be written as a +standalone program and associated with a command such as ``trinity-attach``. However, using a +subcommand ``attach`` is the more idiomatic approach and this type of plugin gives us simple way +to develop exactly that. + +We build this kind of plugin by subclassing from +:class:`~trinity.extensibility.plugin.BaseMainProcessPlugin`. A detailed example will follow soon. + + +Plugins that spawn their own new isolated process +------------------------------------------------- + +Of course, if all what plugins could do is to hijack the `trinity` command, there wouldn't be +much room to actually extend the *runtime functionality* of Trinity. If we want to create plugins +that boot with and run alongside the main node activity, we need to write a different kind of +plugin. These type of plugins can respond to events such as a peers connecting/disconnecting and +can access information that is only available within the running application. + +The JSON-RPC API is a great example as it exposes information such as the current count +of connected peers which is live information that can only be accessed by talking to other parts +of the application at runtime. + +This is the default type of plugin we want to build if: + +- we want to execute logic **together** with the command that boots Trinity (as opposed + to executing it in a separate command) +- we want to execute logic that integrates with parts of Trinity that can only be accessed at + runtime (as opposed to e.g. just reading things from the database) + +We build this kind of plugin subclassing from +:class:`~trinity.extensibility.plugin.BaseIsolatedPlugin`. A detailed example will follow soon. + + +Plugins that run inside the networking process +---------------------------------------------- + +If the previous category sounded as if it could handle every possible use case, it's because it's +actually meant to. In reality though, not all internal APIs yet work well across process +boundaries. In practice, this means that sometimes we want to make sure that a plugin runs in the +same process as the rest of the networking code. + +.. warning:: + The need to run plugins in the networking process is declining as the internals of Trinity become + more and more multi-process friendly over time. While it isn't entirely clear yet, there's a fair + chance this type of plugin will become obsolete at some point and may eventually be removed. + + We should only choose this type of plugin category if what we are trying to build cannot be built + with a :class:`~trinity.extensibility.plugin.BaseIsolatedPlugin`. + +We build this kind of plugin subclassing from +:class:`~trinity.extensibility.plugin.BaseAsyncStopPlugin`. A detailed example will follow soon. + + +The plugin lifecycle +~~~~~~~~~~~~~~~~~~~~ + +Plugins can be in one of the following status at a time: + +- ``NOT_READY`` +- ``READY`` +- ``STARTED`` +- ``STOPPED`` + +The current status of a plugin is also reflected in the +:meth:`~trinity.extensibility.plugin.BasePlugin.status` property. + +.. note:: + + Strictly speaking, there's also a special state that only applies to the + :class:`~trinity.extensibility.plugin.BaseMainProcessPlugin` which comes into effect when such a + plugin hijacks the Trinity process entirely. That being said, in that case, the resulting process + is in fact something entirely different than Trinity and the whole plugin infrastruture doesn't + even continue to exist from the moment on where that plugin takes over the Trinity process. This + is why we do not list it as an actual state of the regular plugin lifecycle. + +Plugin state: ``NOT_READY`` +--------------------------- + +Every plugin starts out being in the ``NOT_READY`` state. This state begins with the instantiation +of the plugin and lasts until the :meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` hook +was called which happens as soon the core infrastructure of Trinity is ready. + +Plugin state: ``READY`` +----------------------- + +After Trinity has finished setting up the core infrastructure, every plugin has its +:class:`~trinity.extensibility.plugin.PluginContext` set and +:meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` is called. At this point the plugin has +access to important information such as the parsed arguments or the +:class:`~trinity.config.TrinityConfig`. It also has access to the central event bus via its +:meth:`~trinity.extensibility.plugin.BasePlugin.event_bus` property which enables the plugin to +communicate with other parts of the application including other plugins. + +Plugin state: ``STARTED`` +------------------------- + +A plugin is in the ``STARTED`` state after the +:meth:`~trinity.extensibility.plugin.BasePlugin.start` method was called. Plugins call this method +themselves whenever they want to start which may be based on some condition like Trinity being +started with certain parameters or some event being propagated on the central event bus. + +.. note:: + Calling :meth:`~trinity.extensibility.plugin.BasePlugin.start` while the plugin is in the + ``NOT_READY`` state or when it is already in ``STARTED`` will cause an exception to be raised. + + +Plugin state: ``STOPPED`` +------------------------- + +A plugin is in the ``STOPPED`` state after the +:meth:`~trinity.extensibility.plugin.BasePlugin.stop` method was called and finished any tear down +work. + +Defining plugins +~~~~~~~~~~~~~~~~ + +We define a plugin by deriving from either +:class:`~trinity.extensibility.plugin.BaseMainProcessPlugin`, +:class:`~trinity.extensibility.plugin.BaseIsolatedPlugin` or +:class:`~trinity.extensibility.plugin.BaseAsyncStopPlugin` depending on the kind of plugin that we +intend to write. For now, we'll stick to :class:`~trinity.extensibility.plugin.BaseIsolatedPlugin` +which is the most commonly used plugin category. + +Every plugin needs to overwrite ``name`` so voilĂ , here's our first plugin! + +.. literalinclude:: ../../../trinity/plugins/examples/peer_count_reporter/plugin.py + :language: python + :start-after: --START CLASS-- + :end-before: def configure_parser + +Of course that doesn't do anything useful yet, bear with us. + +Configuring Command Line Arguments +---------------------------------- + +More often than not we want to have control over if or when a plugin should start. Adding +command-line arguments that are specific to such a plugin, which we then check, validate, and act +on, is a good way to deal with that. Implementing +:meth:`~trinity.extensibility.plugin.BasePlugin.configure_parser` enables us to do exactly that. + +This method is called when Trinity starts and bootstraps the plugin system, in other words, +**before** the start of any plugin. It is passed a :class:`~argparse.ArgumentParser` as well as a +:class:`~argparse._SubParsersAction` which allows it to amend the configuration of Trinity's +command line arguments in many different ways. + +For example, here we are adding a boolean flag ``--report-peer-count`` to Trinity. + +.. literalinclude:: ../../../trinity/plugins/examples/peer_count_reporter/plugin.py + :language: python + :pyobject: PeerCountReporterPlugin.configure_parser + +To be clear, this does not yet cause our plugin to automatically start if ``--report-peer-count`` +is passed, it simply changes the parser to be aware of such flag and hence allows us to check for +its existence later. + +.. note:: + + For a more advanced example, that also configures a subcommand, checkout the ``trinity attach`` + plugin. + + +Defining a plugins starting point +--------------------------------- + +Every plugin needs to have a well defined starting point. The exact mechanics slightly differ +in case of a :class:`~trinity.extensibility.plugin.BaseMainProcessPlugin` but remain fairly similar +for the other types of plugins which we'll be focussing on for now. + +Plugins need to implement the :meth:`~trinity.extensibility.plugin.BasePlugin.do_start` method +to define their own bootstrapping logic. This logic may involve setting up event listeners, running +code in a loop or any other kind of action. + +.. warning:: + + Technically, there's nothing preventing a plugin from performing starting logic in the + :meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` hook. However, doing that is an anti + pattern as the plugin infrastructure won't know about the running plugin, can't propagate the + :class:`~trinity.extensibility.events.PluginStartedEvent` and the plugin won't be properly shut + down with Trinity if the node closes. + +Causing a plugin to start +------------------------- + +As we've read in the previous section not all plugins should run at any point in time. In fact, the +circumstances under which we want a plugin to begin its work may vary from plugin to plugin. + +We may want a plugin to only start running if: + +- a certain (combination) of command line arguments was given +- another plugin or group of plugins started +- a certain number of connected peers was exceeded / undershot +- a certain block number was reached +- ... + +Hence, to actually start a plugin, the plugin needs to invoke the +:meth:`~trinity.extensibility.plugin.BasePlugin.start` method at any moment when it is in its +``READY`` state. + +Communication pattern +~~~~~~~~~~~~~~~~~~~~~ + +Coming soon: Spoiler: Plugins can communicate with other parts of the application or even other +plugins via the central event bus. + +Making plugins discoverable +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Coming soon. + +.. warning:: + **Wait?! This is it? No! This is draft version of the plugin guide as small DEVCON IV gitft. + This will turn into a much more detailed guide shortly after the devcon craze is over.** diff --git a/eth/_warnings.py b/eth/_warnings.py index 6ecaf6b323..366d25aa75 100644 --- a/eth/_warnings.py +++ b/eth/_warnings.py @@ -1,10 +1,11 @@ from contextlib import contextmanager +from typing import Iterator import warnings # TODO: drop once https://github.com/cython/cython/issues/1720 is resolved @contextmanager -def catch_and_ignore_import_warning(): +def catch_and_ignore_import_warning() -> Iterator[None]: with warnings.catch_warnings(): warnings.simplefilter('ignore', category=ImportWarning) yield diff --git a/eth/chains/base.py b/eth/chains/base.py index cca020d992..512a6f1113 100644 --- a/eth/chains/base.py +++ b/eth/chains/base.py @@ -24,6 +24,7 @@ import logging from eth_typing import ( + Address, BlockNumber, Hash32, ) @@ -241,9 +242,14 @@ def create_transaction(self, *args: Any, **kwargs: Any) -> BaseTransaction: raise NotImplementedError("Chain classes must implement this method") @abstractmethod - def create_unsigned_transaction(self, - *args: Any, - **kwargs: Any) -> BaseUnsignedTransaction: + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> BaseUnsignedTransaction: raise NotImplementedError("Chain classes must implement this method") @abstractmethod @@ -573,12 +579,24 @@ def create_transaction(self, *args: Any, **kwargs: Any) -> BaseTransaction: return self.get_vm().create_transaction(*args, **kwargs) def create_unsigned_transaction(self, - *args: Any, - **kwargs: Any) -> BaseUnsignedTransaction: + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> BaseUnsignedTransaction: """ Passthrough helper to the current VM class. """ - return self.get_vm().create_unsigned_transaction(*args, **kwargs) + return self.get_vm().create_unsigned_transaction( + nonce=nonce, + gas_price=gas_price, + gas=gas, + to=to, + value=value, + data=data, + ) # # Execution API @@ -592,7 +610,8 @@ def get_transaction_result( This is referred to as a `call()` in web3. """ with self.get_vm(at_header).state_in_temp_block() as state: - computation = state.costless_execute_transaction(transaction) + # Ignore is to not bleed the SpoofTransaction deeper into the code base + computation = state.costless_execute_transaction(transaction) # type: ignore computation.raise_if_error() return computation.output diff --git a/eth/db/chain.py b/eth/db/chain.py index c7e9d41c89..a50535c955 100644 --- a/eth/db/chain.py +++ b/eth/db/chain.py @@ -156,7 +156,7 @@ def get(self, key: bytes) -> bytes: raise NotImplementedError("ChainDB classes must implement this method") @abstractmethod - def persist_trie_data_dict(self, trie_data_dict: Dict[bytes, bytes]) -> None: + def persist_trie_data_dict(self, trie_data_dict: Dict[Hash32, bytes]) -> None: raise NotImplementedError("ChainDB classes must implement this method") @@ -463,7 +463,7 @@ def get(self, key: bytes) -> bytes: """ return self.db[key] - def persist_trie_data_dict(self, trie_data_dict: Dict[bytes, bytes]) -> None: + def persist_trie_data_dict(self, trie_data_dict: Dict[Hash32, bytes]) -> None: """ Store raw trie data to db from a dict """ diff --git a/eth/rlp/headers.py b/eth/rlp/headers.py index c7ae982ca6..a1be902d22 100644 --- a/eth/rlp/headers.py +++ b/eth/rlp/headers.py @@ -203,7 +203,7 @@ def from_parent(cls, return header def create_execution_context( - self, prev_hashes: Union[Tuple[bytes], Tuple[bytes, bytes]]) -> ExecutionContext: + self, prev_hashes: Tuple[Hash32, ...]) -> ExecutionContext: return ExecutionContext( coinbase=self.coinbase, diff --git a/eth/rlp/logs.py b/eth/rlp/logs.py index a19fae7150..adac2e3733 100644 --- a/eth/rlp/logs.py +++ b/eth/rlp/logs.py @@ -4,6 +4,8 @@ binary, ) +from typing import List + from .sedes import ( address, int32, @@ -17,7 +19,7 @@ class Log(rlp.Serializable): ('data', binary) ] - def __init__(self, address: bytes, topics: bytes, data: bytes) -> None: + def __init__(self, address: bytes, topics: List[int], data: bytes) -> None: super().__init__(address, topics, data) @property diff --git a/eth/rlp/transactions.py b/eth/rlp/transactions.py index 168489360b..14684465ad 100644 --- a/eth/rlp/transactions.py +++ b/eth/rlp/transactions.py @@ -2,9 +2,6 @@ ABC, abstractmethod ) -from typing import ( - Any, -) import rlp from rlp.sedes import ( @@ -17,7 +14,9 @@ ) from eth_hash.auto import keccak - +from eth_keys.datatypes import ( + PrivateKey +) from eth_utils import ( ValidationError, ) @@ -150,7 +149,14 @@ def get_message_for_signing(self) -> bytes: @classmethod @abstractmethod - def create_unsigned_transaction(self, *args: Any, **kwargs: Any) -> 'BaseTransaction': + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'BaseUnsignedTransaction': """ Create an unsigned transaction. """ @@ -171,7 +177,7 @@ class BaseUnsignedTransaction(rlp.Serializable, BaseTransactionMethods, ABC): # API that must be implemented by all Transaction subclasses. # @abstractmethod - def as_signed_transaction(self, private_key: bytes) -> 'BaseTransaction': + def as_signed_transaction(self, private_key: PrivateKey) -> 'BaseTransaction': """ Return a version of this transaction which has been signed using the provided `private_key` diff --git a/eth/tools/_utils/hashing.py b/eth/tools/_utils/hashing.py index f5172ccd0d..2276a3f977 100644 --- a/eth/tools/_utils/hashing.py +++ b/eth/tools/_utils/hashing.py @@ -4,6 +4,7 @@ from typing import ( Iterable, + List, Tuple, ) @@ -14,7 +15,7 @@ from eth.rlp.logs import Log -def hash_log_entries(log_entries: Iterable[Tuple[bytes, bytes, bytes]]) -> Hash32: +def hash_log_entries(log_entries: Iterable[Tuple[bytes, List[int], bytes]]) -> Hash32: """ Helper function for computing the RLP hash of the logs from transaction execution. diff --git a/eth/tools/_utils/normalization.py b/eth/tools/_utils/normalization.py index d5b44a42c2..6d0b8e5949 100644 --- a/eth/tools/_utils/normalization.py +++ b/eth/tools/_utils/normalization.py @@ -1,4 +1,3 @@ -import binascii import functools from typing import ( @@ -8,10 +7,7 @@ cast, Dict, Iterable, - List, Mapping, - Sequence, - Tuple, Union, ) @@ -19,9 +15,7 @@ assoc_in, compose, concat, - curry, identity, - merge, ) import cytoolz.curried @@ -41,7 +35,6 @@ is_text, to_bytes, to_canonical_address, - to_dict, ValidationError, ) import eth_utils.curried @@ -60,7 +53,6 @@ GeneralState, IntConvertible, Normalizer, - TransactionDict, TransactionNormalizer, ) @@ -327,214 +319,3 @@ def state_definition_to_dict(state_definition: GeneralState) -> AccountState: normalize_environment = dict_options_normalizer([ normalize_main_environment, ]) - - -# -# Fixture Normalizers -# -def normalize_unsigned_transaction(transaction: TransactionDict, - indexes: Dict[str, Any]) -> TransactionDict: - - normalized = normalize_transaction_group(transaction) - return merge(normalized, { - # Dynamic key access not yet allowed with TypedDict - # https://github.com/python/mypy/issues/5359 - transaction_key: normalized[transaction_key][indexes[index_key]] # type: ignore - for transaction_key, index_key in [ - ("gasLimit", "gas"), - ("value", "value"), - ("data", "data"), - ] - if index_key in indexes - }) - - -def normalize_account_state(account_state: AccountState) -> AccountState: - return { - to_canonical_address(address): { - 'balance': to_int(state['balance']), - 'code': decode_hex(state['code']), - 'nonce': to_int(state['nonce']), - 'storage': { - to_int(slot): big_endian_to_int(decode_hex(value)) - for slot, value in state['storage'].items() - }, - } for address, state in account_state.items() - } - - -@to_dict -def normalize_post_state(post_state: Dict[str, Any]) -> Iterable[Tuple[str, bytes]]: - yield 'hash', decode_hex(post_state['hash']) - if 'logs' in post_state: - yield 'logs', decode_hex(post_state['logs']) - - -@curry -def normalize_statetest_fixture(fixture: Dict[str, Any], - fork: str, - post_state_index: int) -> Dict[str, Any]: - - post_state = fixture['post'][fork][post_state_index] - - normalized_fixture = { - 'env': normalize_environment(fixture['env']), - 'pre': normalize_account_state(fixture['pre']), - 'post': normalize_post_state(post_state), - 'transaction': normalize_unsigned_transaction( - fixture['transaction'], - post_state['indexes'], - ), - } - - return normalized_fixture - - -def normalize_exec(exec_params: Dict[str, Any]) -> Dict[str, Any]: - return { - 'origin': to_canonical_address(exec_params['origin']), - 'address': to_canonical_address(exec_params['address']), - 'caller': to_canonical_address(exec_params['caller']), - 'value': to_int(exec_params['value']), - 'data': decode_hex(exec_params['data']), - 'gas': to_int(exec_params['gas']), - 'gasPrice': to_int(exec_params['gasPrice']), - } - - -def normalize_callcreates(callcreates: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]: - return [ - { - 'data': decode_hex(created_call['data']), - 'destination': ( - to_canonical_address(created_call['destination']) - if created_call['destination'] - else CREATE_CONTRACT_ADDRESS - ), - 'gasLimit': to_int(created_call['gasLimit']), - 'value': to_int(created_call['value']), - } for created_call in callcreates - ] - - -@to_dict -def normalize_vmtest_fixture(fixture: Dict[str, Any]) -> Iterable[Tuple[str, Any]]: - yield 'env', normalize_environment(fixture['env']) - yield 'exec', normalize_exec(fixture['exec']) - yield 'pre', normalize_account_state(fixture['pre']) - - if 'post' in fixture: - yield 'post', normalize_account_state(fixture['post']) - - if 'callcreates' in fixture: - yield 'callcreates', normalize_callcreates(fixture['callcreates']) - - if 'gas' in fixture: - yield 'gas', to_int(fixture['gas']) - - if 'out' in fixture: - yield 'out', decode_hex(fixture['out']) - - if 'logs' in fixture: - yield 'logs', decode_hex(fixture['logs']) - - -def normalize_signed_transaction(transaction: Dict[str, Any]) -> Dict[str, Any]: - return { - 'data': robust_decode_hex(transaction['data']), - 'gasLimit': to_int(transaction['gasLimit']), - 'gasPrice': to_int(transaction['gasPrice']), - 'nonce': to_int(transaction['nonce']), - 'r': to_int(transaction['r']), - 's': to_int(transaction['s']), - 'v': to_int(transaction['v']), - 'to': decode_hex(transaction['to']), - 'value': to_int(transaction['value']), - } - - -@curry -def normalize_transactiontest_fixture(fixture: Dict[str, Any], fork: str) -> Dict[str, Any]: - - normalized_fixture = {} - - fork_data = fixture[fork] - - try: - normalized_fixture['rlp'] = decode_hex(fixture['rlp']) - except binascii.Error: - normalized_fixture['rlpHex'] = fixture['rlp'] - - if "sender" in fork_data: - normalized_fixture['sender'] = fork_data['sender'] - - if "hash" in fork_data: - normalized_fixture['hash'] = fork_data['hash'] - - return normalized_fixture - - -def normalize_block_header(header: Dict[str, Any]) -> Dict[str, Any]: - normalized_header = { - 'bloom': big_endian_to_int(decode_hex(header['bloom'])), - 'coinbase': to_canonical_address(header['coinbase']), - 'difficulty': to_int(header['difficulty']), - 'extraData': decode_hex(header['extraData']), - 'gasLimit': to_int(header['gasLimit']), - 'gasUsed': to_int(header['gasUsed']), - 'hash': decode_hex(header['hash']), - 'mixHash': decode_hex(header['mixHash']), - 'nonce': decode_hex(header['nonce']), - 'number': to_int(header['number']), - 'parentHash': decode_hex(header['parentHash']), - 'receiptTrie': decode_hex(header['receiptTrie']), - 'stateRoot': decode_hex(header['stateRoot']), - 'timestamp': to_int(header['timestamp']), - 'transactionsTrie': decode_hex(header['transactionsTrie']), - 'uncleHash': decode_hex(header['uncleHash']), - } - if 'blocknumber' in header: - normalized_header['blocknumber'] = to_int(header['blocknumber']) - if 'chainname' in header: - normalized_header['chainname'] = header['chainname'] - if 'chainnetwork' in header: - normalized_header['chainnetwork'] = header['chainnetwork'] - return normalized_header - - -def normalize_block(block: Dict[str, Any]) -> Dict[str, Any]: - normalized_block = {} - - try: - normalized_block['rlp'] = decode_hex(block['rlp']) - except ValueError as err: - normalized_block['rlp_error'] = err - - if 'blockHeader' in block: - normalized_block['blockHeader'] = normalize_block_header(block['blockHeader']) - if 'transactions' in block: - normalized_block['transactions'] = [ - normalize_signed_transaction(transaction) - for transaction - in block['transactions'] - ] - return normalized_block - - -def normalize_blockchain_fixtures(fixture: Dict[str, Any]) -> Dict[str, Any]: - normalized_fixture = { - 'blocks': [normalize_block(block_fixture) for block_fixture in fixture['blocks']], - 'genesisBlockHeader': normalize_block_header(fixture['genesisBlockHeader']), - 'lastblockhash': decode_hex(fixture['lastblockhash']), - 'pre': normalize_account_state(fixture['pre']), - 'postState': normalize_account_state(fixture['postState']), - 'network': fixture['network'], - } - - if 'sealEngine' in fixture: - normalized_fixture['sealEngine'] = fixture['sealEngine'] - - if 'genesisRLP' in fixture: - normalized_fixture['genesisRLP'] = decode_hex(fixture['genesisRLP']) - - return normalized_fixture diff --git a/eth/tools/fixtures/fillers/_utils.py b/eth/tools/fixtures/fillers/_utils.py index 018fa9c885..81c2d79042 100644 --- a/eth/tools/fixtures/fillers/_utils.py +++ b/eth/tools/fixtures/fillers/_utils.py @@ -72,7 +72,7 @@ def calc_state_root(state: AccountState, account_db_class: Type[BaseAccountDB]) def generate_random_keypair() -> Tuple[bytes, Address]: key_object = keys.PrivateKey(pad32(int_to_big_endian(random.getrandbits(8 * 32)))) - return key_object.to_bytes(), key_object.public_key.to_canonical_address() + return key_object.to_bytes(), Address(key_object.public_key.to_canonical_address()) def generate_random_address() -> Address: diff --git a/eth/tools/fixtures/fillers/vm.py b/eth/tools/fixtures/fillers/vm.py index 8dd8fcc11f..39844c27cd 100644 --- a/eth/tools/fixtures/fillers/vm.py +++ b/eth/tools/fixtures/fillers/vm.py @@ -2,6 +2,7 @@ Any, Dict, Iterable, + List, Tuple, Union, ) @@ -28,7 +29,7 @@ def fill_vm_test( call_creates: Any=None, gas_price: Union[int, str]=None, gas_remaining: Union[int, str]=0, - logs: Iterable[Tuple[bytes, bytes, bytes]]=None, + logs: Iterable[Tuple[bytes, List[int], bytes]]=None, output: bytes=b"") -> Dict[str, Dict[str, Any]]: test_name = get_test_name(filler) diff --git a/eth/tools/fixtures/normalization.py b/eth/tools/fixtures/normalization.py index 9448a2df97..6fa025f3e4 100644 --- a/eth/tools/fixtures/normalization.py +++ b/eth/tools/fixtures/normalization.py @@ -1,46 +1,28 @@ import binascii -import functools from typing import ( Any, - AnyStr, - Callable, Dict, Iterable, List, - Mapping, Sequence, Tuple, ) from cytoolz import ( - assoc_in, - compose, - concat, curry, - identity, merge, ) -import cytoolz.curried - -from eth_typing import ( - Address, -) from eth_utils import ( - apply_formatters_to_dict, big_endian_to_int, decode_hex, - is_0x_prefixed, is_bytes, is_hex, - is_integer, - is_string, is_text, to_bytes, to_canonical_address, to_dict, - ValidationError, ) import eth_utils.curried @@ -48,18 +30,14 @@ CREATE_CONTRACT_ADDRESS, ) -from eth.tools._utils.mappings import ( - deep_merge, - is_cleanly_mergable, -) from eth.tools._utils.normalization import ( - normalize_transaction_group + normalize_environment, + normalize_transaction_group, + to_int, ) from eth.typing import ( AccountState, - GeneralState, - Normalizer, TransactionDict, ) @@ -67,27 +45,6 @@ # # Primitives # -@functools.lru_cache(maxsize=1024) -def normalize_int(value: Any) -> int: - """ - Robust to integer conversion, handling hex values, string representations, - and special cases like `0x`. - """ - if is_integer(value): - return value - elif is_bytes(value): - return big_endian_to_int(value) - elif is_hex(value) and is_0x_prefixed(value): - if len(value) == 2: - return 0 - else: - return int(value, 16) - elif is_string(value): - return int(value) - else: - raise TypeError("Unsupported type: Got `{0}`".format(type(value))) - - def normalize_bytes(value: Any) -> bytes: if is_bytes(value): return value @@ -97,202 +54,9 @@ def normalize_bytes(value: Any) -> bytes: raise TypeError("Value must be either a string or bytes object") -@functools.lru_cache(maxsize=1024) -def to_int(value: str) -> int: - """ - Robust to integer conversion, handling hex values, string representations, - and special cases like `0x`. - """ - if is_0x_prefixed(value): - if len(value) == 2: - return 0 - else: - return int(value, 16) - else: - return int(value) - - -@functools.lru_cache(maxsize=128) -def normalize_to_address(value: AnyStr) -> Address: - if value: - return to_canonical_address(value) - else: - return CREATE_CONTRACT_ADDRESS - - robust_decode_hex = eth_utils.curried.hexstr_if_str(to_bytes) -# -# Containers -# -def dict_normalizer(formatters: Dict[Any, Callable[..., Any]], - required: Iterable[Any]=None, - optional: Iterable[Any]=None) -> Normalizer: - - all_keys = set(formatters.keys()) - - if required is None and optional is None: - required_set_form = all_keys - elif required is not None and optional is not None: - raise ValueError("Both required and optional keys specified") - elif required is not None: - required_set_form = set(required) - elif optional is not None: - required_set_form = all_keys - set(optional) - - def normalizer(d: Dict[Any, Any]) -> Dict[str, Any]: - keys = set(d.keys()) - missing_keys = required_set_form - keys - superfluous_keys = keys - all_keys - if missing_keys: - raise KeyError("Missing required keys: {}".format(", ".join(missing_keys))) - if superfluous_keys: - raise KeyError("Superfluous keys: {}".format(", ".join(superfluous_keys))) - - return apply_formatters_to_dict(formatters, d) - - return normalizer - - -def dict_options_normalizer(normalizers: Iterable[Normalizer]) -> Normalizer: - - def normalize(d: Dict[Any, Any]) -> Dict[str, Any]: - first_exception = None - for normalizer in normalizers: - try: - normalized = normalizer(d) - except KeyError as e: - if not first_exception: - first_exception = e - else: - return normalized - assert first_exception is not None - raise first_exception - - return normalize - - -# -# Composition -# -def state_definition_to_dict(state_definition: GeneralState) -> AccountState: - """Convert a state definition to the canonical dict form. - - State can either be defined in the canonical form, or as a list of sub states that are then - merged to one. Sub states can either be given as dictionaries themselves, or as tuples where - the last element is the value and all others the keys for this value in the nested state - dictionary. Example: - - ``` - [ - ("0xaabb", "balance", 3), - ("0xaabb", "storage", { - 4: 5, - }), - "0xbbcc", { - "balance": 6, - "nonce": 7 - } - ] - ``` - """ - if isinstance(state_definition, Mapping): - state_dict = state_definition - elif isinstance(state_definition, Iterable): - state_dicts = [ - assoc_in( - {}, - state_item[:-1], - state_item[-1] - ) if not isinstance(state_item, Mapping) else state_item - for state_item - in state_definition - ] - if not is_cleanly_mergable(*state_dicts): - raise ValidationError("Some state item is defined multiple times") - state_dict = deep_merge(*state_dicts) - else: - assert TypeError("State definition must either be a mapping or a sequence") - - seen_keys = set(concat(d.keys() for d in state_dict.values())) - bad_keys = seen_keys - set(["balance", "nonce", "storage", "code"]) - if bad_keys: - raise ValidationError( - "State definition contains the following invalid account fields: {}".format( - ", ".join(bad_keys) - ) - ) - - return state_dict - - -normalize_storage = compose( - cytoolz.curried.keymap(normalize_int), - cytoolz.curried.valmap(normalize_int), -) - - -normalize_state = compose( - cytoolz.curried.keymap(to_canonical_address), - cytoolz.curried.valmap(dict_normalizer({ - "balance": normalize_int, - "code": normalize_bytes, - "nonce": normalize_int, - "storage": normalize_storage - }, required=[])), - eth_utils.curried.apply_formatter_if( - lambda s: isinstance(s, Iterable) and not isinstance(s, Mapping), - state_definition_to_dict - ), -) - - -normalize_execution = dict_normalizer({ - "address": to_canonical_address, - "origin": to_canonical_address, - "caller": to_canonical_address, - "value": normalize_int, - "data": normalize_bytes, - "gasPrice": normalize_int, - "gas": normalize_int, -}) - - -normalize_networks = identity - - -normalize_call_create_item = dict_normalizer({ - "data": normalize_bytes, - "destination": to_canonical_address, - "gasLimit": normalize_int, - "value": normalize_int, -}) -normalize_call_creates = eth_utils.curried.apply_formatter_to_array(normalize_call_create_item) - -normalize_log_item = dict_normalizer({ - "address": to_canonical_address, - "topics": eth_utils.curried.apply_formatter_to_array(normalize_int), - "data": normalize_bytes, -}) -normalize_logs = eth_utils.curried.apply_formatter_to_array(normalize_log_item) - - -normalize_main_environment = dict_normalizer({ - "currentCoinbase": to_canonical_address, - "previousHash": normalize_bytes, - "currentNumber": normalize_int, - "currentDifficulty": normalize_int, - "currentGasLimit": normalize_int, - "currentTimestamp": normalize_int, -}, optional=["previousHash"]) - - -normalize_environment = dict_options_normalizer([ - normalize_main_environment, -]) - - # # Fixture Normalizers # diff --git a/eth/utils/db.py b/eth/utils/db.py index 07b82896b8..1dd3a149d0 100644 --- a/eth/utils/db.py +++ b/eth/utils/db.py @@ -2,6 +2,10 @@ TYPE_CHECKING, ) +from eth_typing import ( + Hash32, +) + from eth.db.account import ( BaseAccountDB, ) @@ -24,7 +28,7 @@ def get_parent_header(block_header: BlockHeader, db: 'BaseChainDB') -> BlockHead return db.get_block_header_by_hash(block_header.parent_hash) -def get_block_header_by_hash(block_hash: BlockHeader, db: 'BaseChainDB') -> BlockHeader: +def get_block_header_by_hash(block_hash: Hash32, db: 'BaseChainDB') -> BlockHeader: """ Returns the header for the parent block. """ diff --git a/eth/utils/transactions.py b/eth/utils/transactions.py index 1003ca68b3..fec60ae5b9 100644 --- a/eth/utils/transactions.py +++ b/eth/utils/transactions.py @@ -11,6 +11,7 @@ ValidationError, ) from eth.typing import ( + Address, VRS, ) from eth.utils.numeric import ( @@ -91,7 +92,7 @@ def validate_transaction_signature(transaction: BaseTransaction) -> None: raise ValidationError("Invalid Signature") -def extract_transaction_sender(transaction: BaseTransaction) -> bytes: +def extract_transaction_sender(transaction: BaseTransaction) -> Address: if is_eip_155_signed_transaction(transaction): if is_even(transaction.v): v = 28 @@ -108,4 +109,4 @@ def extract_transaction_sender(transaction: BaseTransaction) -> bytes: message = transaction.get_message_for_signing() public_key = signature.recover_public_key_from_msg(message) sender = public_key.to_canonical_address() - return sender + return Address(sender) diff --git a/eth/validation.py b/eth/validation.py index e04b287223..bdec8fe8fc 100644 --- a/eth/validation.py +++ b/eth/validation.py @@ -1,5 +1,21 @@ import functools +from typing import ( + Any, + Dict, + Iterable, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +from eth_typing import ( + Address, + Hash32, +) + from eth_utils import ( ValidationError, ) @@ -24,22 +40,25 @@ UINT_256_MAX, ) +if TYPE_CHECKING: + from eth.vm.base import BaseVM # noqa: F401 + -def validate_is_bytes(value, title="Value"): +def validate_is_bytes(value: bytes, title: str="Value") -> None: if not isinstance(value, bytes): raise ValidationError( "{title} must be a byte string. Got: {0}".format(type(value), title=title) ) -def validate_is_integer(value, title="Value"): +def validate_is_integer(value: Union[int, bool], title: str="Value") -> None: if not isinstance(value, int) or isinstance(value, bool): raise ValidationError( "{title} must be a an integer. Got: {0}".format(type(value), title=title) ) -def validate_length(value, length, title="Value"): +def validate_length(value: Sequence[Any], length: int, title: str="Value") -> None: if not len(value) == length: raise ValidationError( "{title} must be of length {0}. Got {1} of length {2}".format( @@ -51,7 +70,7 @@ def validate_length(value, length, title="Value"): ) -def validate_length_lte(value, maximum_length, title="Value"): +def validate_length_lte(value: Sequence[Any], maximum_length: int, title: str="Value") -> None: if len(value) > maximum_length: raise ValidationError( "{title} must be of length less than or equal to {0}. " @@ -64,7 +83,7 @@ def validate_length_lte(value, maximum_length, title="Value"): ) -def validate_gte(value, minimum, title="Value"): +def validate_gte(value: int, minimum: int, title: str="Value") -> None: if value < minimum: raise ValidationError( "{title} {0} is not greater than or equal to {1}".format( @@ -76,7 +95,7 @@ def validate_gte(value, minimum, title="Value"): validate_is_integer(value) -def validate_gt(value, minimum, title="Value"): +def validate_gt(value: int, minimum: int, title: str="Value") -> None: if value <= minimum: raise ValidationError( "{title} {0} is not greater than {1}".format(value, minimum, title=title) @@ -84,7 +103,7 @@ def validate_gt(value, minimum, title="Value"): validate_is_integer(value, title=title) -def validate_lte(value, maximum, title="Value"): +def validate_lte(value: int, maximum: int, title: str="Value") -> None: if value > maximum: raise ValidationError( "{title} {0} is not less than or equal to {1}".format( @@ -96,7 +115,7 @@ def validate_lte(value, maximum, title="Value"): validate_is_integer(value, title=title) -def validate_lt(value, maximum, title="Value"): +def validate_lt(value: int, maximum: int, title: str="Value") -> None: if value >= maximum: raise ValidationError( "{title} {0} is not less than {1}".format(value, maximum, title=title) @@ -104,28 +123,28 @@ def validate_lt(value, maximum, title="Value"): validate_is_integer(value, title=title) -def validate_canonical_address(value, title="Value"): +def validate_canonical_address(value: Address, title: str="Value") -> None: if not isinstance(value, bytes) or not len(value) == 20: raise ValidationError( "{title} {0} is not a valid canonical address".format(value, title=title) ) -def validate_multiple_of(value, multiple_of, title="Value"): +def validate_multiple_of(value: int, multiple_of: int, title: str="Value") -> None: if not value % multiple_of == 0: raise ValidationError( "{title} {0} is not a multiple of {1}".format(value, multiple_of, title=title) ) -def validate_is_boolean(value, title="Value"): +def validate_is_boolean(value: bool, title: str="Value") -> None: if not isinstance(value, bool): raise ValidationError( "{title} must be an boolean. Got type: {0}".format(type(value), title=title) ) -def validate_word(value, title="Value"): +def validate_word(value: Hash32, title: str="Value") -> None: if not isinstance(value, bytes): raise ValidationError( "{title} is not a valid word. Must be of bytes type: Got: {0}".format( @@ -142,7 +161,7 @@ def validate_word(value, title="Value"): ) -def validate_uint256(value, title="Value"): +def validate_uint256(value: int, title: str="Value") -> None: if not isinstance(value, int) or isinstance(value, bool): raise ValidationError( "{title} must be an integer: Got: {0}".format( @@ -166,7 +185,7 @@ def validate_uint256(value, title="Value"): ) -def validate_stack_item(value): +def validate_stack_item(value: Union[int, bytes]) -> None: if isinstance(value, bytes) and len(value) <= 32: return elif isinstance(value, int) and 0 <= value <= UINT_256_MAX: @@ -181,7 +200,7 @@ def validate_stack_item(value): validate_lt_secpk1n2 = functools.partial(validate_lte, maximum=SECPK1_N // 2 - 1) -def validate_unique(values, title="Value"): +def validate_unique(values: Iterable[Any], title: str="Value") -> None: if not isdistinct(values): duplicates = pipe( values, @@ -198,19 +217,19 @@ def validate_unique(values, title="Value"): ) -def validate_block_number(block_number, title="Block Number"): +def validate_block_number(block_number: int, title: str="Block Number") -> None: validate_is_integer(block_number, title) validate_gte(block_number, 0, title) -def validate_vm_block_numbers(vm_block_numbers): +def validate_vm_block_numbers(vm_block_numbers: Iterable[int]) -> None: validate_unique(vm_block_numbers, title="Block Number set") for block_number in vm_block_numbers: validate_block_number(block_number) -def validate_vm_configuration(vm_configuration): +def validate_vm_configuration(vm_configuration: Tuple[Tuple[int, Type['BaseVM']], ...]) -> None: validate_vm_block_numbers(tuple( block_number for block_number, _ @@ -218,7 +237,7 @@ def validate_vm_configuration(vm_configuration): )) -def validate_gas_limit(gas_limit, parent_gas_limit): +def validate_gas_limit(gas_limit: int, parent_gas_limit: int) -> None: if gas_limit < GAS_LIMIT_MINIMUM: raise ValidationError("Gas limit {0} is below minimum {1}".format( gas_limit, GAS_LIMIT_MINIMUM)) @@ -245,7 +264,7 @@ def validate_gas_limit(gas_limit, parent_gas_limit): } -def validate_header_params_for_configuration(header_params): +def validate_header_params_for_configuration(header_params: Dict[str, Any]) -> None: extra_fields = set(header_params.keys()).difference(ALLOWED_HEADER_FIELDS) if extra_fields: raise ValidationError( diff --git a/eth/vm/base.py b/eth/vm/base.py index bd75bd5b5d..71cfcf7774 100644 --- a/eth/vm/base.py +++ b/eth/vm/base.py @@ -7,6 +7,11 @@ import functools import logging from typing import ( + Any, + Iterable, + Iterator, + Optional, + Tuple, Type, ) @@ -16,6 +21,11 @@ BloomFilter, ) +from eth_typing import ( + Address, + Hash32, +) + from eth_utils import ( to_tuple, ValidationError, @@ -45,6 +55,10 @@ from eth.rlp.sedes import ( int32, ) +from eth.rlp.transactions import ( + BaseTransaction, + BaseUnsignedTransaction, +) from eth.utils.datatypes import ( Configurable, ) @@ -62,7 +76,8 @@ from eth.vm.message import ( Message, ) -from eth.vm.state import BaseState # noqa: F401 +from eth.vm.state import BaseState +from eth.vm.computation import BaseComputation class BaseVM(Configurable, ABC): @@ -74,11 +89,11 @@ class BaseVM(Configurable, ABC): @property @abstractmethod - def state(self): + def state(self) -> BaseState: raise NotImplementedError("VM classes must implement this property") @abstractmethod - def __init__(self, header, chaindb): + def __init__(self, header: BlockHeader, chaindb: BaseChainDB) -> None: pass # @@ -86,31 +101,38 @@ def __init__(self, header, chaindb): # @property @abstractmethod - def logger(self): + def logger(self) -> logging.Logger: raise NotImplementedError("VM classes must implement this method") # # Execution # @abstractmethod - def apply_transaction(self, header, transaction): + def apply_transaction(self, + header: BlockHeader, + transaction: BaseTransaction + ) -> Tuple[BlockHeader, Receipt, BaseComputation]: raise NotImplementedError("VM classes must implement this method") @abstractmethod def execute_bytecode(self, - origin, - gas_price, - gas, - to, - sender, - value, - data, - code, - code_address=None): + origin: Address, + gas_price: int, + gas: int, + to: Address, + sender: Address, + value: int, + data: bytes, + code: bytes, + code_address: Address=None) -> BaseComputation: raise NotImplementedError("VM classes must implement this method") @abstractmethod - def make_receipt(self, base_header, transaction, computation, state): + def make_receipt(self, + base_header: BlockHeader, + transaction: BaseTransaction, + computation: BaseComputation, + state: BaseState) -> Receipt: """ Generate the receipt resulting from applying the transaction. @@ -127,26 +149,30 @@ def make_receipt(self, base_header, transaction, computation, state): # Mining # @abstractmethod - def import_block(self, block): + def import_block(self, block: BaseBlock) -> BaseBlock: raise NotImplementedError("VM classes must implement this method") @abstractmethod - def mine_block(self, *args, **kwargs): + def mine_block(self, *args: Any, **kwargs: Any) -> BaseBlock: raise NotImplementedError("VM classes must implement this method") @abstractmethod - def set_block_transactions(self, base_block, new_header, transactions, receipts): + def set_block_transactions(self, + base_block: BaseBlock, + new_header: BlockHeader, + transactions: Tuple[BaseTransaction, ...], + receipts: Tuple[Receipt, ...]) -> BaseBlock: raise NotImplementedError("VM classes must implement this method") # # Finalization # @abstractmethod - def finalize_block(self, block): + def finalize_block(self, block: BaseBlock) -> BaseBlock: raise NotImplementedError("VM classes must implement this method") @abstractmethod - def pack_block(self, block, *args, **kwargs): + def pack_block(self, block: BaseBlock, *args: Any, **kwargs: Any) -> BaseBlock: raise NotImplementedError("VM classes must implement this method") # @@ -154,7 +180,7 @@ def pack_block(self, block, *args, **kwargs): # @classmethod @abstractmethod - def compute_difficulty(cls, parent_header, timestamp): + def compute_difficulty(cls, parent_header: BlockHeader, timestamp: int) -> int: """ Compute the difficulty for a block header. @@ -164,7 +190,7 @@ def compute_difficulty(cls, parent_header, timestamp): raise NotImplementedError("VM classes must implement this method") @abstractmethod - def configure_header(self, **header_params): + def configure_header(self, **header_params: Any) -> BlockHeader: """ Setup the current header with the provided parameters. This can be used to set fields like the gas limit or timestamp to value different @@ -174,7 +200,9 @@ def configure_header(self, **header_params): @classmethod @abstractmethod - def create_header_from_parent(cls, parent_header, **header_params): + def create_header_from_parent(cls, + parent_header: BlockHeader, + **header_params: Any) -> BlockHeader: """ Creates and initializes a new block header from the provided `parent_header`. @@ -186,7 +214,9 @@ def create_header_from_parent(cls, parent_header, **header_params): # @classmethod @abstractmethod - def generate_block_from_parent_header_and_coinbase(cls, parent_header, coinbase): + def generate_block_from_parent_header_and_coinbase(cls, + parent_header: BlockHeader, + coinbase: Address) -> BaseBlock: raise NotImplementedError("VM classes must implement this method") @classmethod @@ -219,7 +249,10 @@ def get_nephew_reward(cls) -> int: @classmethod @abstractmethod - def get_prev_hashes(cls, last_block_hash, chaindb): + @to_tuple + def get_prev_hashes(cls, + last_block_hash: Hash32, + chaindb: BaseChainDB) -> Optional[Iterable[Hash32]]: raise NotImplementedError("VM classes must implement this method") @staticmethod @@ -237,17 +270,24 @@ def get_uncle_reward(block_number: int, uncle: BaseBlock) -> int: # Transactions # @abstractmethod - def create_transaction(self, *args, **kwargs): + def create_transaction(self, *args: Any, **kwargs: Any) -> BaseTransaction: raise NotImplementedError("VM classes must implement this method") @classmethod @abstractmethod - def create_unsigned_transaction(cls, *args, **kwargs): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> BaseUnsignedTransaction: raise NotImplementedError("VM classes must implement this method") @classmethod @abstractmethod - def get_transaction_class(cls): + def get_transaction_class(cls) -> Type[BaseTransaction]: raise NotImplementedError("VM classes must implement this method") # @@ -259,7 +299,7 @@ def validate_receipt(self, receipt: Receipt) -> None: raise NotImplementedError("VM classes must implement this method") @abstractmethod - def validate_block(self, block): + def validate_block(self, block: BaseBlock) -> None: raise NotImplementedError("VM classes must implement this method") @classmethod @@ -269,7 +309,9 @@ def validate_header( raise NotImplementedError("VM classes must implement this method") @abstractmethod - def validate_transaction_against_header(self, base_header, transaction): + def validate_transaction_against_header(self, + base_header: BlockHeader, + transaction: BaseTransaction) -> None: """ Validate that the given transaction is valid to apply to the given header. @@ -296,12 +338,12 @@ def validate_uncle( # @classmethod @abstractmethod - def get_state_class(cls): + def get_state_class(cls) -> Type[BaseState]: raise NotImplementedError("VM classes must implement this method") @abstractmethod @contextlib.contextmanager - def state_in_temp_block(self): + def state_in_temp_block(self) -> Iterator[BaseState]: raise NotImplementedError("VM classes must implement this method") @@ -320,12 +362,12 @@ class VM(BaseVM): _state = None - def __init__(self, header, chaindb): + def __init__(self, header: BlockHeader, chaindb: BaseChainDB) -> None: self.chaindb = chaindb self.block = self.get_block_class().from_header(header=header, chaindb=self.chaindb) @property - def state(self): + def state(self) -> BaseState: if self._state is None: self._state = self.get_state_class()( db=self.chaindb.db, @@ -338,13 +380,16 @@ def state(self): # Logging # @property - def logger(self): + def logger(self) -> logging.Logger: return logging.getLogger('eth.vm.base.VM.{0}'.format(self.__class__.__name__)) # # Execution # - def apply_transaction(self, header, transaction): + def apply_transaction(self, + header: BlockHeader, + transaction: BaseTransaction + ) -> Tuple[BlockHeader, Receipt, BaseComputation]: """ Apply the transaction to the current block. This is a wrapper around :func:`~eth.vm.state.State.apply_transaction` with some extra orchestration logic. @@ -366,16 +411,16 @@ def apply_transaction(self, header, transaction): return new_header, receipt, computation def execute_bytecode(self, - origin, - gas_price, - gas, - to, - sender, - value, - data, - code, - code_address=None, - ): + origin: Address, + gas_price: int, + gas: int, + to: Address, + sender: Address, + value: int, + data: bytes, + code: bytes, + code_address: Address=None, + ) -> BaseComputation: """ Execute raw bytecode in the context of the current state of the virtual machine. @@ -407,7 +452,9 @@ def execute_bytecode(self, transaction_context, ) - def apply_all_transactions(self, transactions, base_header): + def apply_all_transactions(self, + transactions: Tuple[BaseTransaction, ...], + base_header: BlockHeader) -> Tuple[BlockHeader, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: # noqa: E501 """ Determine the results of applying all transactions to the base header. This does *not* update the current block or header of the VM. @@ -440,12 +487,15 @@ def apply_all_transactions(self, transactions, base_header): receipts.append(receipt) computations.append(computation) - return result_header, receipts, computations + receipts_tuple = tuple(receipts) + computations_tuple = tuple(computations) + + return result_header, receipts_tuple, computations_tuple # # Mining # - def import_block(self, block): + def import_block(self, block: BaseBlock) -> BaseBlock: """ Import the given block to the chain. """ @@ -488,7 +538,7 @@ def import_block(self, block): return self.mine_block() - def mine_block(self, *args, **kwargs): + def mine_block(self, *args: Any, **kwargs: Any) -> BaseBlock: """ Mine the current block. Proxies to self.pack_block method. """ @@ -504,7 +554,11 @@ def mine_block(self, *args, **kwargs): return final_block - def set_block_transactions(self, base_block, new_header, transactions, receipts): + def set_block_transactions(self, + base_block: BaseBlock, + new_header: BlockHeader, + transactions: Tuple[BaseTransaction, ...], + receipts: Tuple[Receipt, ...]) -> BaseBlock: tx_root_hash, tx_kv_nodes = make_trie_root_and_nodes(transactions) self.chaindb.persist_trie_data_dict(tx_kv_nodes) @@ -523,7 +577,7 @@ def set_block_transactions(self, base_block, new_header, transactions, receipts) # # Finalization # - def finalize_block(self, block): + def finalize_block(self, block: BaseBlock) -> BaseBlock: """ Perform any finalization steps like awarding the block mining reward. """ @@ -553,7 +607,7 @@ def finalize_block(self, block): return block.copy(header=block.header.copy(state_root=self.state.state_root)) - def pack_block(self, block, *args, **kwargs): + def pack_block(self, block: BaseBlock, *args: Any, **kwargs: Any) -> BaseBlock: """ Pack block for mining. @@ -596,7 +650,9 @@ def pack_block(self, block, *args, **kwargs): # Blocks # @classmethod - def generate_block_from_parent_header_and_coinbase(cls, parent_header, coinbase): + def generate_block_from_parent_header_and_coinbase(cls, + parent_header: BlockHeader, + coinbase: Address) -> BaseBlock: """ Generate block from parent header and coinbase. """ @@ -626,7 +682,9 @@ def get_block_class(cls) -> Type[BaseBlock]: @classmethod @functools.lru_cache(maxsize=32) @to_tuple - def get_prev_hashes(cls, last_block_hash, chaindb): + def get_prev_hashes(cls, + last_block_hash: Hash32, + chaindb: BaseChainDB) -> Optional[Iterable[Hash32]]: if last_block_hash == GENESIS_PARENT_HASH: return @@ -640,7 +698,7 @@ def get_prev_hashes(cls, last_block_hash, chaindb): break @property - def previous_hashes(self): + def previous_hashes(self) -> Optional[Tuple[Hash32, ...]]: """ Convenience API for accessing the previous 255 block hashes. """ @@ -649,21 +707,35 @@ def previous_hashes(self): # # Transactions # - def create_transaction(self, *args, **kwargs): + def create_transaction(self, *args: Any, **kwargs: Any) -> BaseTransaction: """ Proxy for instantiating a signed transaction for this VM. """ return self.get_transaction_class()(*args, **kwargs) @classmethod - def create_unsigned_transaction(cls, *args, **kwargs): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'BaseUnsignedTransaction': """ Proxy for instantiating an unsigned transaction for this VM. """ - return cls.get_transaction_class().create_unsigned_transaction(*args, **kwargs) + return cls.get_transaction_class().create_unsigned_transaction( + nonce=nonce, + gas_price=gas_price, + gas=gas, + to=to, + value=value, + data=data + ) @classmethod - def get_transaction_class(cls): + def get_transaction_class(cls) -> Type[BaseTransaction]: """ Return the class that this VM uses for transactions. """ @@ -688,7 +760,7 @@ def validate_receipt(cls, receipt: Receipt) -> None: "filter.".format(topic_idx, log_idx) ) - def validate_block(self, block): + def validate_block(self, block: BaseBlock) -> None: """ Validate the the given block. """ @@ -739,8 +811,10 @@ def validate_block(self, block): ) @classmethod - def validate_header( - cls, header: BlockHeader, parent_header: BlockHeader, check_seal: bool = True) -> None: + def validate_header(cls, + header: BlockHeader, + parent_header: BlockHeader, + check_seal: bool = True) -> None: """ :raise eth.exceptions.ValidationError: if the header is not valid """ @@ -785,7 +859,7 @@ def validate_seal(cls, header: BlockHeader) -> None: header.mix_hash, header.nonce, header.difficulty) @classmethod - def validate_uncle(cls, block, uncle, uncle_parent): + def validate_uncle(cls, block: BaseBlock, uncle: BaseBlock, uncle_parent: BaseBlock) -> None: """ Validate the given uncle in the context of the given block. """ @@ -811,7 +885,7 @@ def validate_uncle(cls, block, uncle, uncle_parent): # State # @classmethod - def get_state_class(cls): + def get_state_class(cls) -> Type[BaseState]: """ Return the class that this VM uses for states. """ @@ -821,7 +895,7 @@ def get_state_class(cls): return cls._state_class @contextlib.contextmanager - def state_in_temp_block(self): + def state_in_temp_block(self) -> Iterator[BaseState]: header = self.block.header temp_block = self.generate_block_from_parent_header_and_coinbase(header, header.coinbase) prev_hashes = (header.hash, ) + self.previous_hashes diff --git a/eth/vm/code_stream.py b/eth/vm/code_stream.py index e628ca76d5..c19830e3d3 100644 --- a/eth/vm/code_stream.py +++ b/eth/vm/code_stream.py @@ -54,11 +54,11 @@ def peek(self) -> int: return next_opcode @property - def pc(self): + def pc(self) -> int: return self.stream.tell() @pc.setter - def pc(self, value): + def pc(self, value: int) -> None: self.stream.seek(min(value, len(self))) @contextlib.contextmanager diff --git a/eth/vm/computation.py b/eth/vm/computation.py index 26bfa62a22..86bdf912cc 100644 --- a/eth/vm/computation.py +++ b/eth/vm/computation.py @@ -12,10 +12,11 @@ Iterator, List, Tuple, + Union, ) from eth_typing import ( - Address + Address, ) from eth.constants import ( @@ -27,7 +28,7 @@ VMError, ) from eth.tools.logging import ( - TraceLogger + TraceLogger, ) from eth.utils.datatypes import ( Configurable, @@ -114,7 +115,7 @@ class BaseComputation(Configurable, ABC): accounts_to_delete = None # type: Dict[bytes, bytes] # VM configuration - opcodes = None # type: Dict[int, Opcode] + opcodes = None # type: Dict[int, Any] _precompiles = None # type: Dict[Address, Callable[['BaseComputation'], Any]] logger = cast(TraceLogger, logging.getLogger('eth.vm.computation.Computation')) @@ -198,7 +199,7 @@ def should_erase_return_data(self) -> bool: # def prepare_child_message(self, gas: int, - to: bytes, + to: Address, value: int, data: bytes, code: bytes, @@ -291,7 +292,9 @@ def refund_gas(self, amount: int) -> None: """ return self._gas_meter.refund_gas(amount) - def stack_pop(self, num_items=1, type_hint=None): + def stack_pop(self, num_items: int=1, type_hint: str=None) -> Any: + # TODO: Needs to be replaced with + # `Union[int, bytes, Tuple[Union[int, bytes], ...]]` if done properly """ Pop and return a number of items equal to ``num_items`` from the stack. ``type_hint`` can be either ``'uint256'`` or ``'bytes'``. The return value @@ -303,7 +306,7 @@ def stack_pop(self, num_items=1, type_hint=None): """ return self._stack.pop(num_items, type_hint) - def stack_push(self, value): + def stack_push(self, value: Union[int, bytes]) -> None: """ Push ``value`` onto the stack. @@ -311,13 +314,13 @@ def stack_push(self, value): """ return self._stack.push(value) - def stack_swap(self, position): + def stack_swap(self, position: int) -> None: """ Swap the item on the top of the stack with the item at ``position``. """ return self._stack.swap(position) - def stack_dup(self, position): + def stack_dup(self, position: int) -> None: """ Duplicate the stack item at ``position`` and pushes it onto the stack. """ @@ -575,7 +578,7 @@ def precompiles(self) -> Dict[Address, Callable[['BaseComputation'], Any]]: else: return self._precompiles - def get_opcode_fn(self, opcode): + def get_opcode_fn(self, opcode: int) -> Opcode: try: return self.opcodes[opcode] except KeyError: diff --git a/eth/vm/execution_context.py b/eth/vm/execution_context.py index 68f450cc98..ca453b28c0 100644 --- a/eth/vm/execution_context.py +++ b/eth/vm/execution_context.py @@ -1,3 +1,11 @@ +from typing import Tuple + +from eth_typing import ( + Address, + Hash32, +) + + class ExecutionContext: _coinbase = None @@ -9,12 +17,12 @@ class ExecutionContext: def __init__( self, - coinbase, - timestamp, - block_number, - difficulty, - gas_limit, - prev_hashes): + coinbase: Address, + timestamp: int, + block_number: int, + difficulty: int, + gas_limit: int, + prev_hashes: Tuple[Hash32, ...]) -> None: self._coinbase = coinbase self._timestamp = timestamp self._block_number = block_number @@ -23,25 +31,25 @@ def __init__( self._prev_hashes = prev_hashes @property - def coinbase(self): + def coinbase(self) -> Address: return self._coinbase @property - def timestamp(self): + def timestamp(self) -> int: return self._timestamp @property - def block_number(self): + def block_number(self) -> int: return self._block_number @property - def difficulty(self): + def difficulty(self) -> int: return self._difficulty @property - def gas_limit(self): + def gas_limit(self) -> int: return self._gas_limit @property - def prev_hashes(self): + def prev_hashes(self) -> Tuple[Hash32, ...]: return self._prev_hashes diff --git a/eth/vm/forks/byzantium/__init__.py b/eth/vm/forks/byzantium/__init__.py index f0346511be..f1e9e7f703 100644 --- a/eth/vm/forks/byzantium/__init__.py +++ b/eth/vm/forks/byzantium/__init__.py @@ -15,12 +15,15 @@ MAX_UNCLE_DEPTH, ) from eth.rlp.blocks import BaseBlock # noqa: F401 +from eth.rlp.headers import BlockHeader from eth.rlp.receipts import Receipt +from eth.rlp.transactions import BaseTransaction from eth.validation import ( validate_lte, ) from eth.vm.forks.spurious_dragon import SpuriousDragonVM from eth.vm.forks.frontier import make_frontier_receipt +from eth.vm.computation import BaseComputation from eth.vm.state import BaseState # noqa: F401 from .blocks import ByzantiumBlock @@ -37,7 +40,10 @@ from .state import ByzantiumState -def make_byzantium_receipt(base_header, transaction, computation, state): +def make_byzantium_receipt(base_header: BlockHeader, + transaction: BaseTransaction, + computation: BaseComputation, + state: BaseState) -> Receipt: frontier_receipt = make_frontier_receipt(base_header, transaction, computation, state) if computation.is_error: @@ -49,7 +55,7 @@ def make_byzantium_receipt(base_header, transaction, computation, state): @curry -def get_uncle_reward(block_reward, block_number, uncle): +def get_uncle_reward(block_reward: int, block_number: int, uncle: BaseBlock) -> int: block_number_delta = block_number - uncle.block_number validate_lte(block_number_delta, MAX_UNCLE_DEPTH) return (8 - block_number_delta) * block_reward // 8 @@ -70,10 +76,10 @@ class ByzantiumVM(SpuriousDragonVM): _state_class = ByzantiumState # type: Type[BaseState] # Methods - create_header_from_parent = staticmethod(create_byzantium_header_from_parent) - compute_difficulty = staticmethod(compute_byzantium_difficulty) + create_header_from_parent = staticmethod(create_byzantium_header_from_parent) # type: ignore + compute_difficulty = staticmethod(compute_byzantium_difficulty) # type: ignore configure_header = configure_byzantium_header - make_receipt = staticmethod(make_byzantium_receipt) + make_receipt = staticmethod(make_byzantium_receipt) # type: ignore # Separated into two steps due to mypy bug of staticmethod. # https://github.com/python/mypy/issues/5530 get_uncle_reward = get_uncle_reward(EIP649_BLOCK_REWARD) @@ -93,5 +99,5 @@ def validate_receipt(cls, receipt: Receipt) -> None: ) @staticmethod - def get_block_reward(): + def get_block_reward() -> int: return EIP649_BLOCK_REWARD diff --git a/eth/vm/forks/byzantium/opcodes.py b/eth/vm/forks/byzantium/opcodes.py index 37d83bdb49..986f2abf6f 100644 --- a/eth/vm/forks/byzantium/opcodes.py +++ b/eth/vm/forks/byzantium/opcodes.py @@ -3,6 +3,11 @@ from cytoolz import merge +from typing import ( + Any, + Callable, +) + from eth import constants from eth.exceptions import ( @@ -10,6 +15,7 @@ ) from eth.vm import mnemonics from eth.vm import opcode_values +from eth.vm.computation import BaseComputation from eth.vm.forks.tangerine_whistle.constants import ( GAS_CALL_EIP150, GAS_SELFDESTRUCT_EIP150 @@ -26,9 +32,9 @@ from eth.vm.forks.spurious_dragon.opcodes import SPURIOUS_DRAGON_OPCODES -def ensure_no_static(opcode_fn): +def ensure_no_static(opcode_fn: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(opcode_fn) - def inner(computation): + def inner(computation: BaseComputation) -> Callable[..., Any]: if computation.msg.is_static: raise WriteProtection("Cannot modify state while inside of a STATICCALL context") return opcode_fn(computation) diff --git a/eth/vm/forks/byzantium/transactions.py b/eth/vm/forks/byzantium/transactions.py index 5450583b52..95554aa710 100644 --- a/eth/vm/forks/byzantium/transactions.py +++ b/eth/vm/forks/byzantium/transactions.py @@ -1,21 +1,33 @@ -from eth.vm.forks.spurious_dragon.transactions import ( - SpuriousDragonTransaction, - SpuriousDragonUnsignedTransaction, -) +from eth_keys.datatypes import PrivateKey +from eth_typing import Address from eth.utils.transactions import ( create_transaction_signature, ) +from eth.vm.forks.spurious_dragon.transactions import ( + SpuriousDragonTransaction, + SpuriousDragonUnsignedTransaction, +) + class ByzantiumTransaction(SpuriousDragonTransaction): @classmethod - def create_unsigned_transaction(cls, *, nonce, gas_price, gas, to, value, data): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'ByzantiumUnsignedTransaction': return ByzantiumUnsignedTransaction(nonce, gas_price, gas, to, value, data) class ByzantiumUnsignedTransaction(SpuriousDragonUnsignedTransaction): - def as_signed_transaction(self, private_key, chain_id=None): + def as_signed_transaction(self, + private_key: PrivateKey, + chain_id: int=None) -> ByzantiumTransaction: v, r, s = create_transaction_signature(self, private_key, chain_id=chain_id) return ByzantiumTransaction( nonce=self.nonce, diff --git a/eth/vm/forks/constantinople/__init__.py b/eth/vm/forks/constantinople/__init__.py index 552df68c97..9dd79459e4 100644 --- a/eth/vm/forks/constantinople/__init__.py +++ b/eth/vm/forks/constantinople/__init__.py @@ -28,11 +28,11 @@ class ConstantinopleVM(ByzantiumVM): _state_class = ConstantinopleState # type: Type[BaseState] # Methods - create_header_from_parent = staticmethod(create_constantinople_header_from_parent) - compute_difficulty = staticmethod(compute_constantinople_difficulty) + create_header_from_parent = staticmethod(create_constantinople_header_from_parent) # type: ignore # noqa: E501 + compute_difficulty = staticmethod(compute_constantinople_difficulty) # type: ignore configure_header = configure_constantinople_header get_uncle_reward = staticmethod(get_uncle_reward(EIP1234_BLOCK_REWARD)) @staticmethod - def get_block_reward(): + def get_block_reward() -> int: return EIP1234_BLOCK_REWARD diff --git a/eth/vm/forks/constantinople/storage.py b/eth/vm/forks/constantinople/storage.py index 4d99a1ae2a..61e96875e6 100644 --- a/eth/vm/forks/constantinople/storage.py +++ b/eth/vm/forks/constantinople/storage.py @@ -1,6 +1,8 @@ from eth.constants import ( UINT256 ) + +from eth.vm.computation import BaseComputation from eth.vm.forks.constantinople import ( constants ) @@ -10,7 +12,7 @@ ) -def sstore_eip1283(computation): +def sstore_eip1283(computation: BaseComputation) -> None: slot, value = computation.stack_pop(num_items=2, type_hint=UINT256) current_value = computation.state.account_db.get_storage( diff --git a/eth/vm/forks/constantinople/transactions.py b/eth/vm/forks/constantinople/transactions.py index 41c3dde00b..e3bf87b0ed 100644 --- a/eth/vm/forks/constantinople/transactions.py +++ b/eth/vm/forks/constantinople/transactions.py @@ -1,3 +1,6 @@ +from eth_keys.datatypes import PrivateKey +from eth_typing import Address + from eth.vm.forks.byzantium.transactions import ( ByzantiumTransaction, ByzantiumUnsignedTransaction, @@ -10,12 +13,21 @@ class ConstantinopleTransaction(ByzantiumTransaction): @classmethod - def create_unsigned_transaction(cls, *, nonce, gas_price, gas, to, value, data): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'ConstantinopleUnsignedTransaction': return ConstantinopleUnsignedTransaction(nonce, gas_price, gas, to, value, data) class ConstantinopleUnsignedTransaction(ByzantiumUnsignedTransaction): - def as_signed_transaction(self, private_key, chain_id=None): + def as_signed_transaction(self, + private_key: PrivateKey, + chain_id: int=None) -> ConstantinopleTransaction: v, r, s = create_transaction_signature(self, private_key, chain_id=chain_id) return ConstantinopleTransaction( nonce=self.nonce, diff --git a/eth/vm/forks/frontier/__init__.py b/eth/vm/forks/frontier/__init__.py index 94f44ae4d0..a91f24df46 100644 --- a/eth/vm/forks/frontier/__init__.py +++ b/eth/vm/forks/frontier/__init__.py @@ -1,19 +1,19 @@ from typing import Type # noqa: F401 -from eth.rlp.blocks import BaseBlock # noqa: F401 -from eth.vm.state import BaseState # noqa: F401 - from eth.constants import ( BLOCK_REWARD, UNCLE_DEPTH_PENALTY_FACTOR, ) + +from eth.rlp.blocks import BaseBlock # noqa: F401 +from eth.rlp.headers import BlockHeader +from eth.rlp.logs import Log +from eth.rlp.receipts import Receipt +from eth.rlp.transactions import BaseTransaction + from eth.vm.base import VM -from eth.rlp.receipts import ( - Receipt, -) -from eth.rlp.logs import ( - Log, -) +from eth.vm.computation import BaseComputation +from eth.vm.state import BaseState # noqa: F401 from .blocks import FrontierBlock from .state import FrontierState @@ -25,7 +25,10 @@ from .validation import validate_frontier_transaction_against_header -def make_frontier_receipt(base_header, transaction, computation, state): +def make_frontier_receipt(base_header: BlockHeader, + transaction: BaseTransaction, + computation: BaseComputation, + state: BaseState) -> Receipt: # Reusable for other forks logs = [ @@ -62,22 +65,22 @@ class FrontierVM(VM): _state_class = FrontierState # type: Type[BaseState] # methods - create_header_from_parent = staticmethod(create_frontier_header_from_parent) - compute_difficulty = staticmethod(compute_frontier_difficulty) + create_header_from_parent = staticmethod(create_frontier_header_from_parent) # type: ignore + compute_difficulty = staticmethod(compute_frontier_difficulty) # type: ignore configure_header = configure_frontier_header - make_receipt = staticmethod(make_frontier_receipt) + make_receipt = staticmethod(make_frontier_receipt) # type: ignore validate_transaction_against_header = validate_frontier_transaction_against_header @staticmethod - def get_block_reward(): + def get_block_reward() -> int: return BLOCK_REWARD @staticmethod - def get_uncle_reward(block_number, uncle): + def get_uncle_reward(block_number: int, uncle: BaseBlock) -> int: return BLOCK_REWARD * ( UNCLE_DEPTH_PENALTY_FACTOR + uncle.block_number - block_number ) // UNCLE_DEPTH_PENALTY_FACTOR @classmethod - def get_nephew_reward(cls): + def get_nephew_reward(cls) -> int: return cls.get_block_reward() // 32 diff --git a/eth/vm/forks/frontier/blocks.py b/eth/vm/forks/frontier/blocks.py index 97176f01c2..377cbf751f 100644 --- a/eth/vm/forks/frontier/blocks.py +++ b/eth/vm/forks/frontier/blocks.py @@ -1,5 +1,7 @@ from typing import ( # noqa: F401 - List + Iterable, + List, + Type, ) import rlp @@ -11,20 +13,32 @@ BloomFilter, ) +from eth_typing import ( + Hash32, +) + from eth_hash.auto import keccak from eth.constants import ( EMPTY_UNCLE_HASH, ) -from eth.rlp.receipts import ( - Receipt, + +from eth.db.chain import ( + BaseChainDB, ) + from eth.rlp.blocks import ( BaseBlock, ) from eth.rlp.headers import ( BlockHeader, ) +from eth.rlp.receipts import ( + Receipt, +) +from eth.rlp.transactions import ( + BaseTransaction, +) from .transactions import ( FrontierTransaction, @@ -41,7 +55,10 @@ class FrontierBlock(BaseBlock): bloom_filter = None - def __init__(self, header, transactions=None, uncles=None): + def __init__(self, + header: BlockHeader, + transactions: Iterable[BaseTransaction]=None, + uncles: Iterable[BlockHeader]=None) -> None: if transactions is None: transactions = [] if uncles is None: @@ -60,36 +77,36 @@ def __init__(self, header, transactions=None, uncles=None): # Helpers # @property - def number(self): + def number(self) -> int: return self.header.block_number @property - def hash(self): + def hash(self) -> Hash32: return self.header.hash # # Transaction class for this block class # @classmethod - def get_transaction_class(cls): + def get_transaction_class(cls) -> Type[BaseTransaction]: return cls.transaction_class # # Receipts API # - def get_receipts(self, chaindb): + def get_receipts(self, chaindb: BaseChainDB) -> Iterable[Receipt]: return chaindb.get_receipts(self.header, Receipt) # # Header API # @classmethod - def from_header(cls, header, chaindb): + def from_header(cls, header: BlockHeader, chaindb: BaseChainDB) -> BaseBlock: """ Returns the block denoted by the given block header. """ if header.uncles_hash == EMPTY_UNCLE_HASH: - uncles = [] # type: List[bytes] + uncles = [] # type: List[BlockHeader] else: uncles = chaindb.get_block_uncles(header.uncles_hash) @@ -104,7 +121,7 @@ def from_header(cls, header, chaindb): # # Execution API # - def add_uncle(self, uncle): + def add_uncle(self, uncle: BlockHeader) -> "FrontierBlock": self.uncles.append(uncle) self.header.uncles_hash = keccak(rlp.encode(self.uncles)) return self diff --git a/eth/vm/forks/frontier/computation.py b/eth/vm/forks/frontier/computation.py index 0fa99330ba..08fdc618a4 100644 --- a/eth/vm/forks/frontier/computation.py +++ b/eth/vm/forks/frontier/computation.py @@ -1,18 +1,18 @@ +from eth import precompiles + from eth_hash.auto import keccak from eth.constants import ( GAS_CODEDEPOSIT, STACK_DEPTH_LIMIT, ) -from eth import precompiles -from eth.vm.computation import ( - BaseComputation -) + from eth.exceptions import ( OutOfGas, InsufficientFunds, StackDepthLimit, ) + from eth.utils.address import ( force_bytes_to_address, ) @@ -20,8 +20,13 @@ encode_hex, ) +from eth.vm.computation import ( + BaseComputation, +) + from .opcodes import FRONTIER_OPCODES + FRONTIER_PRECOMPILES = { force_bytes_to_address(b'\x01'): precompiles.ecrecover, force_bytes_to_address(b'\x02'): precompiles.sha256, @@ -39,7 +44,7 @@ class FrontierComputation(BaseComputation): opcodes = FRONTIER_OPCODES _precompiles = FRONTIER_PRECOMPILES - def apply_message(self): + def apply_message(self) -> BaseComputation: snapshot = self.state.snapshot() if self.msg.depth > STACK_DEPTH_LIMIT: @@ -78,7 +83,7 @@ def apply_message(self): return computation - def apply_create_message(self): + def apply_create_message(self) -> BaseComputation: computation = self.apply_message() if computation.is_error: diff --git a/eth/vm/forks/frontier/headers.py b/eth/vm/forks/frontier/headers.py index 36c81352c6..3337010bc9 100644 --- a/eth/vm/forks/frontier/headers.py +++ b/eth/vm/forks/frontier/headers.py @@ -1,5 +1,10 @@ from __future__ import absolute_import +from typing import ( + Any, + TYPE_CHECKING, +) + from eth.validation import ( validate_gt, validate_header_params_for_configuration, @@ -24,8 +29,11 @@ FRONTIER_DIFFICULTY_ADJUSTMENT_CUTOFF ) +if TYPE_CHECKING: + from eth.vm.forks.frontier import FrontierVM # noqa: F401 + -def compute_frontier_difficulty(parent_header, timestamp): +def compute_frontier_difficulty(parent_header: BlockHeader, timestamp: int) -> int: """ Computes the difficulty for a frontier block based on the parent block. """ @@ -65,7 +73,8 @@ def compute_frontier_difficulty(parent_header, timestamp): return difficulty -def create_frontier_header_from_parent(parent_header, **header_params): +def create_frontier_header_from_parent(parent_header: BlockHeader, + **header_params: Any) -> BlockHeader: if 'difficulty' not in header_params: # Use setdefault to ensure the new header has the same timestamp we use to calculate its # difficulty. @@ -85,7 +94,7 @@ def create_frontier_header_from_parent(parent_header, **header_params): return header -def configure_frontier_header(vm, **header_params): +def configure_frontier_header(vm: "FrontierVM", **header_params: Any) -> BlockHeader: validate_header_params_for_configuration(header_params) with vm.block.header.build_changeset(**header_params) as changeset: diff --git a/eth/vm/forks/frontier/opcodes.py b/eth/vm/forks/frontier/opcodes.py index 01b31fd95a..b9b32ea30c 100644 --- a/eth/vm/forks/frontier/opcodes.py +++ b/eth/vm/forks/frontier/opcodes.py @@ -1,4 +1,5 @@ from eth import constants + from eth.vm import mnemonics from eth.vm import opcode_values from eth.vm.logic import ( @@ -17,7 +18,9 @@ swap, system, ) -from eth.vm.opcode import as_opcode +from eth.vm.opcode import ( + as_opcode, +) FRONTIER_OPCODES = { diff --git a/eth/vm/forks/frontier/state.py b/eth/vm/forks/frontier/state.py index 31f96ad753..19a996df04 100644 --- a/eth/vm/forks/frontier/state.py +++ b/eth/vm/forks/frontier/state.py @@ -10,12 +10,9 @@ from eth.exceptions import ( ContractCreationCollision, ) -from eth.vm.message import ( - Message, -) -from eth.vm.state import ( - BaseState, - BaseTransactionExecutor, + +from eth.rlp.transactions import ( + BaseTransaction, ) from eth.utils.address import ( @@ -25,6 +22,18 @@ encode_hex, ) +from eth.vm.computation import ( + BaseComputation, +) +from eth.vm.message import ( + Message, +) +from eth.vm.state import ( + BaseState, + BaseTransactionExecutor, +) + + from .computation import FrontierComputation from .constants import REFUND_SELFDESTRUCT from .transaction_context import ( # noqa: F401 @@ -36,7 +45,7 @@ class FrontierTransactionExecutor(BaseTransactionExecutor): - def validate_transaction(self, transaction): + def validate_transaction(self, transaction: BaseTransaction) -> BaseTransaction: # Validate the transaction transaction.validate() @@ -44,7 +53,7 @@ def validate_transaction(self, transaction): return transaction - def build_evm_message(self, transaction): + def build_evm_message(self, transaction: BaseTransaction) -> Message: gas_fee = transaction.gas * transaction.gas_price @@ -96,7 +105,7 @@ def build_evm_message(self, transaction): ) return message - def build_computation(self, message, transaction): + def build_computation(self, message: Message, transaction: BaseTransaction) -> BaseComputation: """Apply the message to the VM.""" transaction_context = self.vm_state.get_transaction_context(transaction) if message.is_create: @@ -129,7 +138,9 @@ def build_computation(self, message, transaction): return computation - def finalize_computation(self, transaction, computation): + def finalize_computation(self, + transaction: BaseTransaction, + computation: BaseComputation) -> BaseComputation: # Self Destruct Refunds num_deletions = len(computation.get_accounts_for_deletion()) if num_deletions: @@ -181,9 +192,9 @@ class FrontierState(BaseState): account_db_class = AccountDB # Type[BaseAccountDB] transaction_executor = FrontierTransactionExecutor # Type[BaseTransactionExecutor] - def validate_transaction(self, transaction): + def validate_transaction(self, transaction: BaseTransaction) -> None: validate_frontier_transaction(self.account_db, transaction) - def execute_transaction(self, transaction): + def execute_transaction(self, transaction: BaseTransaction) -> BaseTransactionExecutor: executor = self.get_transaction_executor() return executor(transaction) diff --git a/eth/vm/forks/frontier/transactions.py b/eth/vm/forks/frontier/transactions.py index f6f1eea07d..70eb05ef7d 100644 --- a/eth/vm/forks/frontier/transactions.py +++ b/eth/vm/forks/frontier/transactions.py @@ -1,5 +1,11 @@ import rlp +from eth_keys.datatypes import PrivateKey + +from eth_typing import ( + Address, +) + from eth.constants import ( CREATE_CONTRACT_ADDRESS, GAS_TX, @@ -30,10 +36,15 @@ class FrontierTransaction(BaseTransaction): - v_max = 28 - v_min = 27 + @property + def v_min(self) -> int: + return 27 + + @property + def v_max(self) -> int: + return 28 - def validate(self): + def validate(self) -> None: validate_uint256(self.nonce, title="Transaction.nonce") validate_uint256(self.gas_price, title="Transaction.gas_price") validate_uint256(self.gas, title="Transaction.gas") @@ -56,16 +67,16 @@ def validate(self): super().validate() - def check_signature_validity(self): + def check_signature_validity(self) -> None: validate_transaction_signature(self) - def get_sender(self): + def get_sender(self) -> Address: return extract_transaction_sender(self) - def get_intrinsic_gas(self): + def get_intrinsic_gas(self) -> int: return _get_frontier_intrinsic_gas(self.data) - def get_message_for_signing(self): + def get_message_for_signing(self) -> bytes: return rlp.encode(FrontierUnsignedTransaction( nonce=self.nonce, gas_price=self.gas_price, @@ -76,13 +87,20 @@ def get_message_for_signing(self): )) @classmethod - def create_unsigned_transaction(cls, *, nonce, gas_price, gas, to, value, data): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'FrontierUnsignedTransaction': return FrontierUnsignedTransaction(nonce, gas_price, gas, to, value, data) class FrontierUnsignedTransaction(BaseUnsignedTransaction): - def validate(self): + def validate(self) -> None: validate_uint256(self.nonce, title="Transaction.nonce") validate_is_integer(self.gas_price, title="Transaction.gas_price") validate_uint256(self.gas, title="Transaction.gas") @@ -92,7 +110,7 @@ def validate(self): validate_is_bytes(self.data, title="Transaction.data") super().validate() - def as_signed_transaction(self, private_key): + def as_signed_transaction(self, private_key: PrivateKey) -> FrontierTransaction: v, r, s = create_transaction_signature(self, private_key) return FrontierTransaction( nonce=self.nonce, @@ -106,11 +124,11 @@ def as_signed_transaction(self, private_key): s=s, ) - def get_intrinsic_gas(self): + def get_intrinsic_gas(self) -> int: return _get_frontier_intrinsic_gas(self.data) -def _get_frontier_intrinsic_gas(transaction_data): +def _get_frontier_intrinsic_gas(transaction_data: bytes) -> int: num_zero_bytes = transaction_data.count(b'\x00') num_non_zero_bytes = len(transaction_data) - num_zero_bytes return ( diff --git a/eth/vm/forks/frontier/validation.py b/eth/vm/forks/frontier/validation.py index 1a6a6c7ad5..ea85bfad39 100644 --- a/eth/vm/forks/frontier/validation.py +++ b/eth/vm/forks/frontier/validation.py @@ -2,8 +2,15 @@ ValidationError, ) +from eth.db.account import BaseAccountDB -def validate_frontier_transaction(account_db, transaction): +from eth.rlp.headers import BlockHeader +from eth.rlp.transactions import BaseTransaction + +from eth.vm.base import BaseVM + + +def validate_frontier_transaction(account_db: BaseAccountDB, transaction: BaseTransaction) -> None: gas_cost = transaction.gas * transaction.gas_price sender_balance = account_db.get_balance(transaction.sender) @@ -21,7 +28,9 @@ def validate_frontier_transaction(account_db, transaction): raise ValidationError("Invalid transaction nonce") -def validate_frontier_transaction_against_header(_vm, base_header, transaction): +def validate_frontier_transaction_against_header(_vm: BaseVM, + base_header: BlockHeader, + transaction: BaseTransaction) -> None: if base_header.gas_used + transaction.gas > base_header.gas_limit: raise ValidationError( "Transaction exceeds gas limit: using {}, bringing total to {}, but limit is {}".format( diff --git a/eth/vm/forks/homestead/__init__.py b/eth/vm/forks/homestead/__init__.py index 6d9129237b..127958620b 100644 --- a/eth/vm/forks/homestead/__init__.py +++ b/eth/vm/forks/homestead/__init__.py @@ -35,6 +35,6 @@ class HomesteadVM(MetaHomesteadVM): _state_class = HomesteadState # type: Type[BaseState] # method overrides - create_header_from_parent = staticmethod(create_homestead_header_from_parent) - compute_difficulty = staticmethod(compute_homestead_difficulty) + create_header_from_parent = staticmethod(create_homestead_header_from_parent) # type: ignore + compute_difficulty = staticmethod(compute_homestead_difficulty) # type: ignore configure_header = configure_homestead_header diff --git a/eth/vm/forks/homestead/computation.py b/eth/vm/forks/homestead/computation.py index 4638fe1019..ee745a7b9e 100644 --- a/eth/vm/forks/homestead/computation.py +++ b/eth/vm/forks/homestead/computation.py @@ -7,6 +7,7 @@ from eth.utils.hexadecimal import ( encode_hex, ) +from eth.vm.computation import BaseComputation from eth.vm.forks.frontier.computation import ( FrontierComputation, ) @@ -22,7 +23,7 @@ class HomesteadComputation(FrontierComputation): # Override opcodes = HOMESTEAD_OPCODES - def apply_create_message(self): + def apply_create_message(self) -> BaseComputation: snapshot = self.state.snapshot() computation = self.apply_message() diff --git a/eth/vm/forks/homestead/headers.py b/eth/vm/forks/homestead/headers.py index 1acd10769c..1d24e06886 100644 --- a/eth/vm/forks/homestead/headers.py +++ b/eth/vm/forks/homestead/headers.py @@ -1,10 +1,10 @@ -from eth_utils import ( - decode_hex, +from typing import ( + Any, + TYPE_CHECKING, ) -from eth.validation import ( - validate_gt, - validate_header_params_for_configuration, +from eth_utils import ( + decode_hex, ) from eth.constants import ( @@ -13,9 +13,14 @@ BOMB_EXPONENTIAL_PERIOD, BOMB_EXPONENTIAL_FREE_PERIODS, ) +from eth.rlp.headers import BlockHeader from eth.utils.db import ( get_parent_header, ) +from eth.validation import ( + validate_gt, + validate_header_params_for_configuration, +) from eth.vm.forks.frontier.headers import ( create_frontier_header_from_parent, ) @@ -24,8 +29,11 @@ HOMESTEAD_DIFFICULTY_ADJUSTMENT_CUTOFF ) +if TYPE_CHECKING: + from eth.vm.forks.homestead import HomesteadVM # noqa: F401 + -def compute_homestead_difficulty(parent_header, timestamp): +def compute_homestead_difficulty(parent_header: BlockHeader, timestamp: int) -> int: """ Computes the difficulty for a homestead block based on the parent block. """ @@ -47,7 +55,8 @@ def compute_homestead_difficulty(parent_header, timestamp): return difficulty -def create_homestead_header_from_parent(parent_header, **header_params): +def create_homestead_header_from_parent(parent_header: BlockHeader, + **header_params: Any) -> BlockHeader: if 'difficulty' not in header_params: # Use setdefault to ensure the new header has the same timestamp we use to calculate its # difficulty. @@ -59,7 +68,7 @@ def create_homestead_header_from_parent(parent_header, **header_params): return create_frontier_header_from_parent(parent_header, **header_params) -def configure_homestead_header(vm, **header_params): +def configure_homestead_header(vm: "HomesteadVM", **header_params: Any) -> BlockHeader: validate_header_params_for_configuration(header_params) with vm.block.header.build_changeset(**header_params) as changeset: diff --git a/eth/vm/forks/homestead/state.py b/eth/vm/forks/homestead/state.py index 41df8d688b..6dccd70a4e 100644 --- a/eth/vm/forks/homestead/state.py +++ b/eth/vm/forks/homestead/state.py @@ -1,3 +1,5 @@ +from eth.rlp.transactions import BaseTransaction + from eth.vm.forks.frontier.state import ( FrontierState, FrontierTransactionExecutor, @@ -10,7 +12,7 @@ class HomesteadState(FrontierState): computation_class = HomesteadComputation - def validate_transaction(self, transaction): + def validate_transaction(self, transaction: BaseTransaction) -> None: validate_homestead_transaction(self.account_db, transaction) diff --git a/eth/vm/forks/homestead/transactions.py b/eth/vm/forks/homestead/transactions.py index 00f514f5db..745bbbe038 100644 --- a/eth/vm/forks/homestead/transactions.py +++ b/eth/vm/forks/homestead/transactions.py @@ -1,5 +1,9 @@ import rlp +from eth_keys.datatypes import PrivateKey + +from eth_typing import Address + from eth.constants import ( GAS_TX, GAS_TXCREATE, @@ -7,6 +11,9 @@ GAS_TXDATANONZERO, CREATE_CONTRACT_ADDRESS, ) + +from eth.rlp.transactions import BaseTransaction + from eth.validation import ( validate_lt_secpk1n2, ) @@ -22,14 +29,14 @@ class HomesteadTransaction(FrontierTransaction): - def validate(self): + def validate(self) -> None: super().validate() validate_lt_secpk1n2(self.s, title="Transaction.s") - def get_intrinsic_gas(self): + def get_intrinsic_gas(self) -> int: return _get_homestead_intrinsic_gas(self) - def get_message_for_signing(self): + def get_message_for_signing(self) -> bytes: return rlp.encode(HomesteadUnsignedTransaction( nonce=self.nonce, gas_price=self.gas_price, @@ -40,12 +47,19 @@ def get_message_for_signing(self): )) @classmethod - def create_unsigned_transaction(cls, *, nonce, gas_price, gas, to, value, data): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'HomesteadUnsignedTransaction': return HomesteadUnsignedTransaction(nonce, gas_price, gas, to, value, data) class HomesteadUnsignedTransaction(FrontierUnsignedTransaction): - def as_signed_transaction(self, private_key): + def as_signed_transaction(self, private_key: PrivateKey) -> HomesteadTransaction: v, r, s = create_transaction_signature(self, private_key) return HomesteadTransaction( nonce=self.nonce, @@ -59,11 +73,11 @@ def as_signed_transaction(self, private_key): s=s, ) - def get_intrinsic_gas(self): + def get_intrinsic_gas(self) -> int: return _get_homestead_intrinsic_gas(self) -def _get_homestead_intrinsic_gas(transaction): +def _get_homestead_intrinsic_gas(transaction: BaseTransaction) -> int: num_zero_bytes = transaction.data.count(b'\x00') num_non_zero_bytes = len(transaction.data) - num_zero_bytes if transaction.to == CREATE_CONTRACT_ADDRESS: diff --git a/eth/vm/forks/homestead/validation.py b/eth/vm/forks/homestead/validation.py index 4d0db7a93c..add2be3c68 100644 --- a/eth/vm/forks/homestead/validation.py +++ b/eth/vm/forks/homestead/validation.py @@ -6,12 +6,16 @@ SECPK1_N, ) +from eth.db.account import BaseAccountDB + +from eth.rlp.transactions import BaseTransaction + from eth.vm.forks.frontier.validation import ( validate_frontier_transaction, ) -def validate_homestead_transaction(account_db, transaction): +def validate_homestead_transaction(account_db: BaseAccountDB, transaction: BaseTransaction) -> None: if transaction.s > SECPK1_N // 2 or transaction.s == 0: raise ValidationError("Invalid signature S value") diff --git a/eth/vm/forks/spurious_dragon/computation.py b/eth/vm/forks/spurious_dragon/computation.py index a8a77f823e..a053b3bd74 100644 --- a/eth/vm/forks/spurious_dragon/computation.py +++ b/eth/vm/forks/spurious_dragon/computation.py @@ -7,6 +7,7 @@ from eth.utils.hexadecimal import ( encode_hex, ) +from eth.vm.computation import BaseComputation from eth.vm.forks.homestead.computation import ( HomesteadComputation, ) @@ -23,7 +24,7 @@ class SpuriousDragonComputation(HomesteadComputation): # Override opcodes = SPURIOUS_DRAGON_OPCODES - def apply_create_message(self): + def apply_create_message(self) -> BaseComputation: snapshot = self.state.snapshot() # EIP161 nonce incrementation diff --git a/eth/vm/forks/spurious_dragon/state.py b/eth/vm/forks/spurious_dragon/state.py index d1b5e93671..be0019c0b9 100644 --- a/eth/vm/forks/spurious_dragon/state.py +++ b/eth/vm/forks/spurious_dragon/state.py @@ -1,7 +1,11 @@ +from eth.rlp.transactions import BaseTransaction from eth.utils.hexadecimal import ( encode_hex, ) + +from eth.vm.computation import BaseComputation + from eth.vm.forks.homestead.state import ( HomesteadState, HomesteadTransactionExecutor, @@ -12,7 +16,9 @@ class SpuriousDragonTransactionExecutor(HomesteadTransactionExecutor): - def finalize_computation(self, transaction, computation): + def finalize_computation(self, + transaction: BaseTransaction, + computation: BaseComputation) -> BaseComputation: computation = super().finalize_computation(transaction, computation) # diff --git a/eth/vm/forks/spurious_dragon/transactions.py b/eth/vm/forks/spurious_dragon/transactions.py index 94a1fe83d4..915a49aa8c 100644 --- a/eth/vm/forks/spurious_dragon/transactions.py +++ b/eth/vm/forks/spurious_dragon/transactions.py @@ -1,3 +1,11 @@ +from typing import Optional + +from eth_keys.datatypes import PrivateKey + +from eth_typing import ( + Address, +) + from eth_utils import ( int_to_big_endian, ) @@ -17,7 +25,7 @@ class SpuriousDragonTransaction(HomesteadTransaction): - def get_message_for_signing(self): + def get_message_for_signing(self) -> bytes: if is_eip_155_signed_transaction(self): txn_parts = rlp.decode(rlp.encode(self)) txn_parts_for_signing = txn_parts[:-3] + [int_to_big_endian(self.chain_id), b'', b''] @@ -33,25 +41,32 @@ def get_message_for_signing(self): )) @classmethod - def create_unsigned_transaction(cls, *, nonce, gas_price, gas, to, value, data): + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> 'SpuriousDragonUnsignedTransaction': return SpuriousDragonUnsignedTransaction(nonce, gas_price, gas, to, value, data) @property - def chain_id(self): + def chain_id(self) -> Optional[int]: if is_eip_155_signed_transaction(self): return extract_chain_id(self.v) else: return None @property - def v_min(self): + def v_min(self) -> int: if is_eip_155_signed_transaction(self): return 35 + (2 * self.chain_id) else: return 27 @property - def v_max(self): + def v_max(self) -> int: if is_eip_155_signed_transaction(self): return 36 + (2 * self.chain_id) else: @@ -59,7 +74,9 @@ def v_max(self): class SpuriousDragonUnsignedTransaction(HomesteadUnsignedTransaction): - def as_signed_transaction(self, private_key, chain_id=None): + def as_signed_transaction(self, + private_key: PrivateKey, + chain_id: int=None) -> SpuriousDragonTransaction: v, r, s = create_transaction_signature(self, private_key, chain_id=chain_id) return SpuriousDragonTransaction( nonce=self.nonce, diff --git a/eth/vm/forks/spurious_dragon/utils.py b/eth/vm/forks/spurious_dragon/utils.py index 0318b1096a..dd144ddab1 100644 --- a/eth/vm/forks/spurious_dragon/utils.py +++ b/eth/vm/forks/spurious_dragon/utils.py @@ -1,3 +1,5 @@ +from typing import Iterable + from eth_utils import to_set from eth import constants @@ -6,12 +8,14 @@ force_bytes_to_address, ) +from eth.vm.computation import BaseComputation + THREE = force_bytes_to_address(b'\x03') @to_set -def collect_touched_accounts(computation): +def collect_touched_accounts(computation: BaseComputation) -> Iterable[bytes]: """ Collect all of the accounts that *may* need to be deleted based on EIP161: diff --git a/eth/vm/logic/arithmetic.py b/eth/vm/logic/arithmetic.py index 3809d77d11..2337ac2905 100644 --- a/eth/vm/logic/arithmetic.py +++ b/eth/vm/logic/arithmetic.py @@ -10,8 +10,10 @@ ceil8, ) +from eth.vm.computation import BaseComputation -def add(computation): + +def add(computation: BaseComputation) -> None: """ Addition """ @@ -22,7 +24,7 @@ def add(computation): computation.stack_push(result) -def addmod(computation): +def addmod(computation: BaseComputation) -> None: """ Modulo Addition """ @@ -36,7 +38,7 @@ def addmod(computation): computation.stack_push(result) -def sub(computation): +def sub(computation: BaseComputation) -> None: """ Subtraction """ @@ -47,7 +49,7 @@ def sub(computation): computation.stack_push(result) -def mod(computation): +def mod(computation: BaseComputation) -> None: """ Modulo """ @@ -61,7 +63,7 @@ def mod(computation): computation.stack_push(result) -def smod(computation): +def smod(computation: BaseComputation) -> None: """ Signed Modulo """ @@ -80,7 +82,7 @@ def smod(computation): computation.stack_push(signed_to_unsigned(result)) -def mul(computation): +def mul(computation: BaseComputation) -> None: """ Multiplication """ @@ -91,7 +93,7 @@ def mul(computation): computation.stack_push(result) -def mulmod(computation): +def mulmod(computation: BaseComputation) -> None: """ Modulo Multiplication """ @@ -104,7 +106,7 @@ def mulmod(computation): computation.stack_push(result) -def div(computation): +def div(computation: BaseComputation) -> None: """ Division """ @@ -118,7 +120,7 @@ def div(computation): computation.stack_push(result) -def sdiv(computation): +def sdiv(computation: BaseComputation) -> None: """ Signed Division """ @@ -138,7 +140,7 @@ def sdiv(computation): @curry -def exp(computation, gas_per_byte): +def exp(computation: BaseComputation, gas_per_byte: int) -> None: """ Exponentiation """ @@ -162,7 +164,7 @@ def exp(computation, gas_per_byte): computation.stack_push(result) -def signextend(computation): +def signextend(computation: BaseComputation) -> None: """ Signed Extend """ @@ -181,7 +183,7 @@ def signextend(computation): computation.stack_push(result) -def shl(computation): +def shl(computation: BaseComputation) -> None: """ Bitwise left shift """ @@ -195,7 +197,7 @@ def shl(computation): computation.stack_push(result) -def shr(computation): +def shr(computation: BaseComputation) -> None: """ Bitwise right shift """ @@ -209,7 +211,7 @@ def shr(computation): computation.stack_push(result) -def sar(computation): +def sar(computation: BaseComputation) -> None: """ Arithmetic bitwise right shift """ diff --git a/eth/vm/logic/block.py b/eth/vm/logic/block.py index af3696d423..b4a269701a 100644 --- a/eth/vm/logic/block.py +++ b/eth/vm/logic/block.py @@ -1,7 +1,9 @@ from eth import constants +from eth.vm.computation import BaseComputation -def blockhash(computation): + +def blockhash(computation: BaseComputation) -> None: block_number = computation.stack_pop(type_hint=constants.UINT256) block_hash = computation.state.get_ancestor_hash(block_number) @@ -9,21 +11,21 @@ def blockhash(computation): computation.stack_push(block_hash) -def coinbase(computation): +def coinbase(computation: BaseComputation) -> None: computation.stack_push(computation.state.coinbase) -def timestamp(computation): +def timestamp(computation: BaseComputation) -> None: computation.stack_push(computation.state.timestamp) -def number(computation): +def number(computation: BaseComputation) -> None: computation.stack_push(computation.state.block_number) -def difficulty(computation): +def difficulty(computation: BaseComputation) -> None: computation.stack_push(computation.state.difficulty) -def gaslimit(computation): +def gaslimit(computation: BaseComputation) -> None: computation.stack_push(computation.state.gas_limit) diff --git a/eth/vm/logic/call.py b/eth/vm/logic/call.py index f8b3dec01b..7769c29e28 100644 --- a/eth/vm/logic/call.py +++ b/eth/vm/logic/call.py @@ -3,8 +3,16 @@ abstractmethod ) +from typing import ( + Tuple, +) + from eth import constants +from eth_typing import ( + Address, +) + from eth.exceptions import ( OutOfGas, WriteProtection, @@ -13,27 +21,40 @@ Opcode, ) +from eth.vm.computation import BaseComputation + from eth.utils.address import ( force_bytes_to_address, ) +CallParams = Tuple[int, int, Address, Address, Address, int, int, int, int, bool, bool] + + class BaseCall(Opcode, ABC): @abstractmethod - def compute_msg_extra_gas(self, computation, gas, to, value): + def compute_msg_extra_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> int: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod - def get_call_params(self, computation): + def get_call_params(self, computation: BaseComputation) -> CallParams: raise NotImplementedError("Must be implemented by subclasses") - def compute_msg_gas(self, computation, gas, to, value): + def compute_msg_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> Tuple[int, int]: extra_gas = self.compute_msg_extra_gas(computation, gas, to, value) total_fee = gas + extra_gas child_msg_gas = gas + (constants.GAS_CALLSTIPEND if value else 0) return child_msg_gas, total_fee - def __call__(self, computation): + def __call__(self, computation: BaseComputation) -> None: computation.consume_gas( self.gas_cost, reason=self.mnemonic, @@ -132,14 +153,18 @@ def __call__(self, computation): class Call(BaseCall): - def compute_msg_extra_gas(self, computation, gas, to, value): + def compute_msg_extra_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> int: account_exists = computation.state.account_db.account_exists(to) transfer_gas_fee = constants.GAS_CALLVALUE if value else 0 create_gas_fee = constants.GAS_NEWACCOUNT if not account_exists else 0 return transfer_gas_fee + create_gas_fee - def get_call_params(self, computation): + def get_call_params(self, computation: BaseComputation) -> CallParams: gas = computation.stack_pop(type_hint=constants.UINT256) to = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) @@ -167,10 +192,14 @@ def get_call_params(self, computation): class CallCode(BaseCall): - def compute_msg_extra_gas(self, computation, gas, to, value): + def compute_msg_extra_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> int: return constants.GAS_CALLVALUE if value else 0 - def get_call_params(self, computation): + def get_call_params(self, computation: BaseComputation) -> CallParams: gas = computation.stack_pop(type_hint=constants.UINT256) code_address = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) @@ -201,13 +230,21 @@ def get_call_params(self, computation): class DelegateCall(BaseCall): - def compute_msg_gas(self, computation, gas, to, value): + def compute_msg_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> Tuple[int, int]: return gas, gas - def compute_msg_extra_gas(self, computation, gas, to, value): + def compute_msg_extra_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> int: return 0 - def get_call_params(self, computation): + def get_call_params(self, computation: BaseComputation) -> CallParams: gas = computation.stack_pop(type_hint=constants.UINT256) code_address = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) @@ -241,7 +278,11 @@ def get_call_params(self, computation): # EIP150 # class CallEIP150(Call): - def compute_msg_gas(self, computation, gas, to, value): + def compute_msg_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> Tuple[int, int]: extra_gas = self.compute_msg_extra_gas(computation, gas, to, value) return compute_eip150_msg_gas( computation=computation, @@ -254,7 +295,11 @@ def compute_msg_gas(self, computation, gas, to, value): class CallCodeEIP150(CallCode): - def compute_msg_gas(self, computation, gas, to, value): + def compute_msg_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> Tuple[int, int]: extra_gas = self.compute_msg_extra_gas(computation, gas, to, value) return compute_eip150_msg_gas( computation=computation, @@ -267,7 +312,11 @@ def compute_msg_gas(self, computation, gas, to, value): class DelegateCallEIP150(DelegateCall): - def compute_msg_gas(self, computation, gas, to, value): + def compute_msg_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> Tuple[int, int]: extra_gas = self.compute_msg_extra_gas(computation, gas, to, value) callstipend = 0 return compute_eip150_msg_gas( @@ -280,11 +329,17 @@ def compute_msg_gas(self, computation, gas, to, value): ) -def max_child_gas_eip150(gas): +def max_child_gas_eip150(gas: int) -> int: return gas - (gas // 64) -def compute_eip150_msg_gas(*, computation, gas, extra_gas, value, mnemonic, callstipend): +def compute_eip150_msg_gas(*, + computation: BaseComputation, + gas: int, + extra_gas: int, + value: int, + mnemonic: str, + callstipend: int) -> Tuple[int, int]: if computation.get_gas_remaining() < extra_gas: # It feels wrong to raise an OutOfGas exception outside of GasMeter, # but I don't see an easy way around it. @@ -305,7 +360,11 @@ def compute_eip150_msg_gas(*, computation, gas, extra_gas, value, mnemonic, call # EIP161 # class CallEIP161(CallEIP150): - def compute_msg_extra_gas(self, computation, gas, to, value): + def compute_msg_extra_gas(self, + computation: BaseComputation, + gas: int, + to: Address, + value: int) -> int: account_is_dead = ( not computation.state.account_db.account_exists(to) or computation.state.account_db.account_is_empty(to) @@ -320,7 +379,7 @@ def compute_msg_extra_gas(self, computation, gas, to, value): # Byzantium # class StaticCall(CallEIP161): - def get_call_params(self, computation): + def get_call_params(self, computation: BaseComputation) -> CallParams: gas = computation.stack_pop(type_hint=constants.UINT256) to = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) @@ -347,7 +406,7 @@ def get_call_params(self, computation): class CallByzantium(CallEIP161): - def get_call_params(self, computation): + def get_call_params(self, computation: BaseComputation) -> CallParams: call_params = super().get_call_params(computation) value = call_params[1] if computation.msg.is_static and value != 0: diff --git a/eth/vm/logic/comparison.py b/eth/vm/logic/comparison.py index 9ae2c0c158..24951346cf 100644 --- a/eth/vm/logic/comparison.py +++ b/eth/vm/logic/comparison.py @@ -5,8 +5,10 @@ unsigned_to_signed, ) +from eth.vm.computation import BaseComputation -def lt(computation): + +def lt(computation: BaseComputation) -> None: """ Lesser Comparison """ @@ -20,7 +22,7 @@ def lt(computation): computation.stack_push(result) -def gt(computation): +def gt(computation: BaseComputation) -> None: """ Greater Comparison """ @@ -34,7 +36,7 @@ def gt(computation): computation.stack_push(result) -def slt(computation): +def slt(computation: BaseComputation) -> None: """ Signed Lesser Comparison """ @@ -51,7 +53,7 @@ def slt(computation): computation.stack_push(signed_to_unsigned(result)) -def sgt(computation): +def sgt(computation: BaseComputation) -> None: """ Signed Greater Comparison """ @@ -68,7 +70,7 @@ def sgt(computation): computation.stack_push(signed_to_unsigned(result)) -def eq(computation): +def eq(computation: BaseComputation) -> None: """ Equality """ @@ -82,7 +84,7 @@ def eq(computation): computation.stack_push(result) -def iszero(computation): +def iszero(computation: BaseComputation) -> None: """ Not """ @@ -96,7 +98,7 @@ def iszero(computation): computation.stack_push(result) -def and_op(computation): +def and_op(computation: BaseComputation) -> None: """ Bitwise And """ @@ -107,7 +109,7 @@ def and_op(computation): computation.stack_push(result) -def or_op(computation): +def or_op(computation: BaseComputation) -> None: """ Bitwise Or """ @@ -118,7 +120,7 @@ def or_op(computation): computation.stack_push(result) -def xor(computation): +def xor(computation: BaseComputation) -> None: """ Bitwise XOr """ @@ -129,7 +131,7 @@ def xor(computation): computation.stack_push(result) -def not_op(computation): +def not_op(computation: BaseComputation) -> None: """ Not """ @@ -140,7 +142,7 @@ def not_op(computation): computation.stack_push(result) -def byte_op(computation): +def byte_op(computation: BaseComputation) -> None: """ Bitwise And """ diff --git a/eth/vm/logic/context.py b/eth/vm/logic/context.py index 1b14886ba2..117da4a74e 100644 --- a/eth/vm/logic/context.py +++ b/eth/vm/logic/context.py @@ -1,4 +1,5 @@ from eth import constants + from eth.exceptions import ( OutOfBoundsRead, ) @@ -10,30 +11,32 @@ ceil32, ) +from eth.vm.computation import BaseComputation + -def balance(computation): +def balance(computation: BaseComputation) -> None: addr = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) balance = computation.state.account_db.get_balance(addr) computation.stack_push(balance) -def origin(computation): +def origin(computation: BaseComputation) -> None: computation.stack_push(computation.transaction_context.origin) -def address(computation): +def address(computation: BaseComputation) -> None: computation.stack_push(computation.msg.storage_address) -def caller(computation): +def caller(computation: BaseComputation) -> None: computation.stack_push(computation.msg.sender) -def callvalue(computation): +def callvalue(computation: BaseComputation) -> None: computation.stack_push(computation.msg.value) -def calldataload(computation): +def calldataload(computation: BaseComputation) -> None: """ Load call data into memory. """ @@ -46,12 +49,12 @@ def calldataload(computation): computation.stack_push(normalized_value) -def calldatasize(computation): +def calldatasize(computation: BaseComputation) -> None: size = len(computation.msg.data) computation.stack_push(size) -def calldatacopy(computation): +def calldatacopy(computation: BaseComputation) -> None: ( mem_start_position, calldata_start_position, @@ -71,12 +74,12 @@ def calldatacopy(computation): computation.memory_write(mem_start_position, size, padded_value) -def codesize(computation): +def codesize(computation: BaseComputation) -> None: size = len(computation.code) computation.stack_push(size) -def codecopy(computation): +def codecopy(computation: BaseComputation) -> None: ( mem_start_position, code_start_position, @@ -101,18 +104,18 @@ def codecopy(computation): computation.memory_write(mem_start_position, size, padded_code_bytes) -def gasprice(computation): +def gasprice(computation: BaseComputation) -> None: computation.stack_push(computation.transaction_context.gas_price) -def extcodesize(computation): +def extcodesize(computation: BaseComputation) -> None: account = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) code_size = len(computation.state.account_db.get_code(account)) computation.stack_push(code_size) -def extcodecopy(computation): +def extcodecopy(computation: BaseComputation) -> None: account = force_bytes_to_address(computation.stack_pop(type_hint=constants.BYTES)) ( mem_start_position, @@ -138,7 +141,7 @@ def extcodecopy(computation): computation.memory_write(mem_start_position, size, padded_code_bytes) -def extcodehash(computation): +def extcodehash(computation: BaseComputation) -> None: """ Return the code hash for a given address. EIP: https://github.com/ethereum/EIPs/blob/master/EIPS/eip-1052.md @@ -152,12 +155,12 @@ def extcodehash(computation): computation.stack_push(account_db.get_code_hash(account)) -def returndatasize(computation): +def returndatasize(computation: BaseComputation) -> None: size = len(computation.return_data) computation.stack_push(size) -def returndatacopy(computation): +def returndatacopy(computation: BaseComputation) -> None: ( mem_start_position, returndata_start_position, diff --git a/eth/vm/logic/duplication.py b/eth/vm/logic/duplication.py index 1fcbfc4d55..4d4b75eb27 100644 --- a/eth/vm/logic/duplication.py +++ b/eth/vm/logic/duplication.py @@ -1,7 +1,9 @@ import functools +from eth.vm.computation import BaseComputation -def dup_XX(computation, position): + +def dup_XX(computation: BaseComputation, position: int) -> None: """ Stack item duplication. """ diff --git a/eth/vm/logic/flow.py b/eth/vm/logic/flow.py index 88a32d5e76..e30294b47c 100644 --- a/eth/vm/logic/flow.py +++ b/eth/vm/logic/flow.py @@ -4,16 +4,18 @@ InvalidInstruction, Halt, ) + +from eth.vm.computation import BaseComputation from eth.vm.opcode_values import ( JUMPDEST, ) -def stop(computation): +def stop(computation: BaseComputation) -> None: raise Halt('STOP') -def jump(computation): +def jump(computation: BaseComputation) -> None: jump_dest = computation.stack_pop(type_hint=constants.UINT256) computation.code.pc = jump_dest @@ -27,7 +29,7 @@ def jump(computation): raise InvalidInstruction("Jump resulted in invalid instruction") -def jumpi(computation): +def jumpi(computation: BaseComputation) -> None: jump_dest, check_value = computation.stack_pop(num_items=2, type_hint=constants.UINT256) if check_value: @@ -42,17 +44,17 @@ def jumpi(computation): raise InvalidInstruction("Jump resulted in invalid instruction") -def jumpdest(computation): +def jumpdest(computation: BaseComputation) -> None: pass -def pc(computation): +def pc(computation: BaseComputation) -> None: pc = max(computation.code.pc - 1, 0) computation.stack_push(pc) -def gas(computation): +def gas(computation: BaseComputation) -> None: gas_remaining = computation.get_gas_remaining() computation.stack_push(gas_remaining) diff --git a/eth/vm/logic/invalid.py b/eth/vm/logic/invalid.py index 0d60b3be73..34a2c944d9 100644 --- a/eth/vm/logic/invalid.py +++ b/eth/vm/logic/invalid.py @@ -1,16 +1,23 @@ +from typing import ( + TYPE_CHECKING, +) + from eth.exceptions import InvalidInstruction from eth.vm.opcode import Opcode +if TYPE_CHECKING: + from eth.vm.computation import BaseComputation # noqa: F401 + class InvalidOpcode(Opcode): mnemonic = "INVALID" gas_cost = 0 - def __init__(self, value): + def __init__(self, value: int) -> None: self.value = value super().__init__() - def __call__(self, computation): + def __call__(self, computation: 'BaseComputation') -> None: raise InvalidInstruction("Invalid opcode 0x{0:x} @ {1}".format( self.value, computation.code.pc - 1, diff --git a/eth/vm/logic/logging.py b/eth/vm/logic/logging.py index 46eb33efc9..ce941ef5b7 100644 --- a/eth/vm/logic/logging.py +++ b/eth/vm/logic/logging.py @@ -4,8 +4,10 @@ from eth import constants +from eth.vm.computation import BaseComputation -def log_XX(computation, topic_count): + +def log_XX(computation: BaseComputation, topic_count: int) -> None: if topic_count < 0 or topic_count > 4: raise TypeError("Invalid log topic size. Must be 0, 1, 2, 3, or 4") diff --git a/eth/vm/logic/memory.py b/eth/vm/logic/memory.py index ef8ab32514..30a395ff22 100644 --- a/eth/vm/logic/memory.py +++ b/eth/vm/logic/memory.py @@ -1,7 +1,9 @@ from eth import constants +from eth.vm.computation import BaseComputation -def mstore(computation): + +def mstore(computation: BaseComputation) -> None: start_position = computation.stack_pop(type_hint=constants.UINT256) value = computation.stack_pop(type_hint=constants.BYTES) @@ -13,7 +15,7 @@ def mstore(computation): computation.memory_write(start_position, 32, normalized_value) -def mstore8(computation): +def mstore8(computation: BaseComputation) -> None: start_position = computation.stack_pop(type_hint=constants.UINT256) value = computation.stack_pop(type_hint=constants.BYTES) @@ -25,7 +27,7 @@ def mstore8(computation): computation.memory_write(start_position, 1, normalized_value) -def mload(computation): +def mload(computation: BaseComputation) -> None: start_position = computation.stack_pop(type_hint=constants.UINT256) computation.extend_memory(start_position, 32) @@ -34,5 +36,5 @@ def mload(computation): computation.stack_push(value) -def msize(computation): +def msize(computation: BaseComputation) -> None: computation.stack_push(len(computation._memory)) diff --git a/eth/vm/logic/sha3.py b/eth/vm/logic/sha3.py index 58092b7fe4..aee319e28f 100644 --- a/eth/vm/logic/sha3.py +++ b/eth/vm/logic/sha3.py @@ -4,9 +4,10 @@ from eth.utils.numeric import ( ceil32, ) +from eth.vm.computation import BaseComputation -def sha3(computation): +def sha3(computation: BaseComputation) -> None: start_position, size = computation.stack_pop(num_items=2, type_hint=constants.UINT256) computation.extend_memory(start_position, size) diff --git a/eth/vm/logic/stack.py b/eth/vm/logic/stack.py index 0a43af5ff9..b5e5aaa237 100644 --- a/eth/vm/logic/stack.py +++ b/eth/vm/logic/stack.py @@ -2,12 +2,14 @@ from eth import constants +from eth.vm.computation import BaseComputation -def pop(computation): + +def pop(computation: BaseComputation) -> None: computation.stack_pop(type_hint=constants.ANY) -def push_XX(computation, size): +def push_XX(computation: BaseComputation, size: int) -> None: raw_value = computation.code.read(size) if not raw_value.strip(b'\x00'): diff --git a/eth/vm/logic/storage.py b/eth/vm/logic/storage.py index 60b05e77e9..42609092be 100644 --- a/eth/vm/logic/storage.py +++ b/eth/vm/logic/storage.py @@ -4,8 +4,10 @@ encode_hex, ) +from eth.vm.computation import BaseComputation -def sstore(computation): + +def sstore(computation: BaseComputation) -> None: slot, value = computation.stack_pop(num_items=2, type_hint=constants.UINT256) current_value = computation.state.account_db.get_storage( @@ -49,7 +51,7 @@ def sstore(computation): ) -def sload(computation): +def sload(computation: BaseComputation) -> None: slot = computation.stack_pop(type_hint=constants.UINT256) value = computation.state.account_db.get_storage( diff --git a/eth/vm/logic/swap.py b/eth/vm/logic/swap.py index ce1be08abf..b2b75d0c1a 100644 --- a/eth/vm/logic/swap.py +++ b/eth/vm/logic/swap.py @@ -1,7 +1,9 @@ import functools +from eth.vm.computation import BaseComputation -def swap_XX(computation, position): + +def swap_XX(computation: BaseComputation, position: int) -> None: """ Stack item swapping """ diff --git a/eth/vm/memory.py b/eth/vm/memory.py index 7b52ce4c97..5e08e7147b 100644 --- a/eth/vm/memory.py +++ b/eth/vm/memory.py @@ -20,7 +20,7 @@ class Memory(object): __slots__ = ['_bytes'] logger = logging.getLogger('eth.vm.memory.Memory') - def __init__(self): + def __init__(self) -> None: self._bytes = bytearray() def extend(self, start_position: int, size: int) -> None: diff --git a/eth/vm/message.py b/eth/vm/message.py index 9a4bee3434..4969646582 100644 --- a/eth/vm/message.py +++ b/eth/vm/message.py @@ -1,5 +1,7 @@ import logging +from eth_typing import Address + from eth.constants import ( CREATE_CONTRACT_ADDRESS, ) @@ -25,17 +27,17 @@ class Message(object): logger = logging.getLogger('eth.vm.message.Message') def __init__(self, - gas, - to, - sender, - value, - data, - code, - depth=0, - create_address=None, - code_address=None, - should_transfer_value=True, - is_static=False): + gas: int, + to: Address, + sender: Address, + value: int, + data: bytes, + code: bytes, + depth: int=0, + create_address: Address=None, + code_address: Address=None, + should_transfer_value: bool=True, + is_static: bool=False) -> None: validate_uint256(gas, title="Message.gas") self.gas = gas # type: int @@ -74,27 +76,27 @@ def __init__(self, self.is_static = is_static @property - def code_address(self): + def code_address(self) -> Address: if self._code_address is not None: return self._code_address else: return self.to @code_address.setter - def code_address(self, value): + def code_address(self, value: Address) -> None: self._code_address = value @property - def storage_address(self): + def storage_address(self) -> Address: if self._storage_address is not None: return self._storage_address else: return self.to @storage_address.setter - def storage_address(self, value): + def storage_address(self, value: Address) -> None: self._storage_address = value @property - def is_create(self): + def is_create(self) -> bool: return self.to == CREATE_CONTRACT_ADDRESS diff --git a/eth/vm/opcode.py b/eth/vm/opcode.py index ddc659c0e0..0d63410d03 100644 --- a/eth/vm/opcode.py +++ b/eth/vm/opcode.py @@ -6,38 +6,59 @@ abstractmethod ) +from typing import ( + Any, + Callable, + cast, + Type, + TypeVar, + TYPE_CHECKING, +) + +from eth.tools.logging import TraceLogger + from eth.utils.datatypes import Configurable +if TYPE_CHECKING: + from computation import BaseComputation # noqa: F401 + + +T = TypeVar('T') + class Opcode(Configurable, ABC): mnemonic = None # type: str gas_cost = None # type: int - def __init__(self): + def __init__(self) -> None: if self.mnemonic is None: raise TypeError("Opcode class {0} missing opcode mnemonic".format(type(self))) if self.gas_cost is None: raise TypeError("Opcode class {0} missing opcode gas_cost".format(type(self))) @abstractmethod - def __call__(self, computation): + def __call__(self, computation: 'BaseComputation') -> Any: """ Hook for performing the actual VM execution. """ raise NotImplementedError("Must be implemented by subclasses") @property - def logger(self): - return logging.getLogger('eth.vm.logic.{0}'.format(self.mnemonic)) + def logger(self) -> TraceLogger: + logger_obj = logging.getLogger('eth.vm.logic.{0}'.format(self.mnemonic)) + return cast(TraceLogger, logger_obj) @classmethod - def as_opcode(cls, logic_fn, mnemonic, gas_cost): + def as_opcode(cls: Type[T], + logic_fn: Callable[..., Any], + mnemonic: str, + gas_cost: int) -> Type[T]: """ Class factory method for turning vanilla functions into Opcode classes. """ if gas_cost: @functools.wraps(logic_fn) - def wrapped_logic_fn(computation): + def wrapped_logic_fn(computation: 'BaseComputation') -> Any: """ Wrapper functionf or the logic function which consumes the base opcode gas cost prior to execution. @@ -58,10 +79,10 @@ def wrapped_logic_fn(computation): opcode_cls = type("opcode:{0}".format(mnemonic), (cls,), props) return opcode_cls() - def __copy__(self): + def __copy__(self) -> 'Opcode': return type(self)() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> 'Opcode': return type(self)() diff --git a/eth/vm/stack.py b/eth/vm/stack.py index 7bf211e283..2c9028b205 100644 --- a/eth/vm/stack.py +++ b/eth/vm/stack.py @@ -14,7 +14,9 @@ ) from typing import ( # noqa: F401 + Generator, List, + Tuple, Union ) from eth_typing import Hash32 # noqa: F401 @@ -27,13 +29,13 @@ class Stack(object): __slots__ = ['values'] logger = logging.getLogger('eth.vm.stack.Stack') - def __init__(self): - self.values = [] # type: List[Union[int, Hash32]] + def __init__(self) -> None: + self.values = [] # type: List[Union[int, bytes]] - def __len__(self): + def __len__(self) -> int: return len(self.values) - def push(self, value): + def push(self, value: Union[int, bytes]) -> None: """ Push an item onto the stack. """ @@ -44,9 +46,11 @@ def push(self, value): self.values.append(value) - def pop(self, num_items, type_hint): + def pop(self, + num_items: int, + type_hint: str) -> Union[int, bytes, Tuple[Union[int, bytes], ...]]: """ - Pop an item off thes stack. + Pop an item off the stack. Note: This function is optimized for speed over readability. """ @@ -58,7 +62,7 @@ def pop(self, num_items, type_hint): except IndexError: raise InsufficientStack("No stack items") - def _pop(self, num_items, type_hint): + def _pop(self, num_items: int, type_hint: str) -> Generator[Union[int, bytes], None, None]: for _ in range(num_items): if type_hint == constants.UINT256: value = self.values.pop() @@ -82,7 +86,7 @@ def _pop(self, num_items, type_hint): ) ) - def swap(self, position): + def swap(self, position: int) -> None: """ Perform a SWAP operation on the stack. """ @@ -92,7 +96,7 @@ def swap(self, position): except IndexError: raise InsufficientStack("Insufficient stack items for SWAP{0}".format(position)) - def dup(self, position): + def dup(self, position: int) -> None: """ Perform a DUP operation on the stack. """ diff --git a/eth/vm/state.py b/eth/vm/state.py index 1a1295f6f6..056177ec94 100644 --- a/eth/vm/state.py +++ b/eth/vm/state.py @@ -5,27 +5,50 @@ import contextlib import logging from typing import ( # noqa: F401 + cast, + Callable, + Iterator, + Tuple, Type, - TYPE_CHECKING + TYPE_CHECKING, +) +from uuid import UUID + +from eth_typing import ( + Address, + Hash32, ) from eth.constants import ( BLANK_ROOT_HASH, MAX_PREV_HEADER_DEPTH, ) -from eth.exceptions import StateRootNotFound from eth.db.account import ( # noqa: F401 BaseAccountDB, AccountDB, ) +from eth.db.backends.base import ( + BaseDB, +) +from eth.exceptions import StateRootNotFound +from eth.tools.logging import ( + TraceLogger, +) from eth.utils.datatypes import ( Configurable, ) +from eth.vm.execution_context import ( + ExecutionContext, +) +from eth.vm.message import Message if TYPE_CHECKING: from eth.computation import ( # noqa: F401 BaseComputation, ) + from eth.rlp.transactions import ( # noqa: F401 + BaseTransaction, + ) from eth.vm.transaction_context import ( # noqa: F401 BaseTransactionContext, ) @@ -57,7 +80,7 @@ class for vm execution. account_db_class = None # type: Type[BaseAccountDB] transaction_executor = None # type: Type[BaseTransactionExecutor] - def __init__(self, db, execution_context, state_root): + def __init__(self, db: BaseDB, execution_context: ExecutionContext, state_root: bytes) -> None: self._db = db self.execution_context = execution_context self.account_db = self.get_account_db_class()(self._db, state_root) @@ -66,43 +89,44 @@ def __init__(self, db, execution_context, state_root): # Logging # @property - def logger(self): - return logging.getLogger('eth.vm.state.{0}'.format(self.__class__.__name__)) + def logger(self) -> TraceLogger: + normal_logger = logging.getLogger('eth.vm.state.{0}'.format(self.__class__.__name__)) + return cast(TraceLogger, normal_logger) # # Block Object Properties (in opcodes) # @property - def coinbase(self): + def coinbase(self) -> Address: """ Return the current ``coinbase`` from the current :attr:`~execution_context` """ return self.execution_context.coinbase @property - def timestamp(self): + def timestamp(self) -> int: """ Return the current ``timestamp`` from the current :attr:`~execution_context` """ return self.execution_context.timestamp @property - def block_number(self): + def block_number(self) -> int: """ Return the current ``block_number`` from the current :attr:`~execution_context` """ return self.execution_context.block_number @property - def difficulty(self): + def difficulty(self) -> int: """ Return the current ``difficulty`` from the current :attr:`~execution_context` """ return self.execution_context.difficulty @property - def gas_limit(self): + def gas_limit(self) -> int: """ Return the current ``gas_limit`` from the current :attr:`~transaction_context` """ @@ -112,7 +136,7 @@ def gas_limit(self): # Access to account db # @classmethod - def get_account_db_class(cls): + def get_account_db_class(cls) -> Type[BaseAccountDB]: """ Return the :class:`~eth.db.account.BaseAccountDB` class that the state class uses. @@ -122,7 +146,7 @@ def get_account_db_class(cls): return cls.account_db_class @property - def state_root(self): + def state_root(self) -> bytes: """ Return the current ``state_root`` from the underlying database """ @@ -131,7 +155,7 @@ def state_root(self): # # Access self._chaindb # - def snapshot(self): + def snapshot(self) -> Tuple[bytes, Tuple[UUID, UUID]]: """ Perform a full snapshot of the current state. @@ -140,7 +164,7 @@ def snapshot(self): """ return (self.state_root, self.account_db.record()) - def revert(self, snapshot): + def revert(self, snapshot: Tuple[bytes, Tuple[UUID, UUID]]) -> None: """ Revert the VM to the state at the snapshot """ @@ -151,7 +175,7 @@ def revert(self, snapshot): # now roll the underlying database back self.account_db.discard(changeset_id) - def commit(self, snapshot): + def commit(self, snapshot: Tuple[bytes, Tuple[UUID, UUID]]) -> None: """ Commit the journal to the point where the snapshot was taken. This will merge in any changesets that were recorded *after* the snapshot changeset. @@ -162,7 +186,7 @@ def commit(self, snapshot): # # Access self.prev_hashes (Read-only) # - def get_ancestor_hash(self, block_number): + def get_ancestor_hash(self, block_number: int) -> Hash32: """ Return the hash for the ancestor block with number ``block_number``. Return the empty bytestring ``b''`` if the block number is outside of the @@ -175,14 +199,16 @@ def get_ancestor_hash(self, block_number): ancestor_depth >= len(self.execution_context.prev_hashes) ) if is_ancestor_depth_out_of_range: - return b'' + return Hash32(b'') ancestor_hash = self.execution_context.prev_hashes[ancestor_depth] return ancestor_hash # # Computation # - def get_computation(self, message, transaction_context): + def get_computation(self, + message: Message, + transaction_context: 'BaseTransactionContext') -> 'BaseComputation': """ Return a computation instance for the given `message` and `transaction_context` """ @@ -196,7 +222,7 @@ def get_computation(self, message, transaction_context): # Transaction context # @classmethod - def get_transaction_context_class(cls): + def get_transaction_context_class(cls) -> Type['BaseTransactionContext']: """ Return the :class:`~eth.vm.transaction_context.BaseTransactionContext` class that the state class uses. @@ -208,7 +234,7 @@ def get_transaction_context_class(cls): # # Execution # - def apply_transaction(self, transaction): + def apply_transaction(self, transaction: 'BaseTransaction') -> Tuple[bytes, 'BaseComputation']: """ Apply transaction to the vm state @@ -221,19 +247,19 @@ def apply_transaction(self, transaction): state_root = self.account_db.make_state_root() return state_root, computation - def get_transaction_executor(self): + def get_transaction_executor(self) -> 'BaseTransactionExecutor': return self.transaction_executor(self) - def costless_execute_transaction(self, transaction): + def costless_execute_transaction(self, transaction: 'BaseTransaction') -> 'BaseComputation': with self.override_transaction_context(gas_price=transaction.gas_price): free_transaction = transaction.copy(gas_price=0) return self.execute_transaction(free_transaction) @contextlib.contextmanager - def override_transaction_context(self, gas_price): + def override_transaction_context(self, gas_price: int) -> Iterator[None]: original_context = self.get_transaction_context - def get_custom_transaction_context(transaction): + def get_custom_transaction_context(transaction: 'BaseTransaction') -> 'BaseTransactionContext': # noqa: E501 custom_transaction = transaction.copy(gas_price=gas_price) return original_context(custom_transaction) @@ -241,18 +267,18 @@ def get_custom_transaction_context(transaction): try: yield finally: - self.get_transaction_context = original_context + self.get_transaction_context = original_context # type: ignore # Remove ignore if https://github.com/python/mypy/issues/708 is fixed. # noqa: E501 @abstractmethod - def execute_transaction(self, transaction): + def execute_transaction(self, transaction: 'BaseTransaction') -> 'BaseComputation': raise NotImplementedError() @abstractmethod - def validate_transaction(self, transaction): + def validate_transaction(self, transaction: 'BaseTransaction') -> None: raise NotImplementedError @classmethod - def get_transaction_context(cls, transaction): + def get_transaction_context(cls, transaction: 'BaseTransaction') -> 'BaseTransactionContext': return cls.get_transaction_context_class()( gas_price=transaction.gas_price, origin=transaction.sender, @@ -260,10 +286,10 @@ def get_transaction_context(cls, transaction): class BaseTransactionExecutor(ABC): - def __init__(self, vm_state): + def __init__(self, vm_state: BaseState) -> None: self.vm_state = vm_state - def __call__(self, transaction): + def __call__(self, transaction: 'BaseTransaction') -> 'BaseComputation': valid_transaction = self.validate_transaction(transaction) message = self.build_evm_message(valid_transaction) computation = self.build_computation(message, valid_transaction) @@ -271,17 +297,21 @@ def __call__(self, transaction): return finalized_computation @abstractmethod - def validate_transaction(self): + def validate_transaction(self, transaction: 'BaseTransaction') -> 'BaseTransaction': raise NotImplementedError @abstractmethod - def build_evm_message(self): + def build_evm_message(self, transaction: 'BaseTransaction') -> Message: raise NotImplementedError() @abstractmethod - def build_computation(self): + def build_computation(self, + message: Message, + transaction: 'BaseTransaction') -> 'BaseComputation': raise NotImplementedError() @abstractmethod - def finalize_computation(self): + def finalize_computation(self, + transaction: 'BaseTransaction', + computation: 'BaseComputation') -> 'BaseComputation': raise NotImplementedError() diff --git a/eth/vm/transaction_context.py b/eth/vm/transaction_context.py index 2bdd543279..d8e0c56bd0 100644 --- a/eth/vm/transaction_context.py +++ b/eth/vm/transaction_context.py @@ -1,5 +1,7 @@ import itertools +from eth_typing import Address + from eth.validation import ( validate_canonical_address, validate_uint256, @@ -13,20 +15,20 @@ class BaseTransactionContext: """ __slots__ = ['_gas_price', '_origin', '_log_counter'] - def __init__(self, gas_price, origin): + def __init__(self, gas_price: int, origin: Address) -> None: validate_uint256(gas_price, title="TransactionContext.gas_price") self._gas_price = gas_price validate_canonical_address(origin, title="TransactionContext.origin") self._origin = origin self._log_counter = itertools.count() - def get_next_log_counter(self): + def get_next_log_counter(self) -> int: return next(self._log_counter) @property - def gas_price(self): + def gas_price(self) -> int: return self._gas_price @property - def origin(self): + def origin(self) -> Address: return self._origin diff --git a/tests/trinity/plugins/test_examples.py b/tests/trinity/plugins/test_examples.py new file mode 100644 index 0000000000..f950a6fa33 --- /dev/null +++ b/tests/trinity/plugins/test_examples.py @@ -0,0 +1,8 @@ +from trinity.plugins.examples import ( + PeerCountReporterPlugin, +) + + +def test_can_instantiate_examples(): + plugin = PeerCountReporterPlugin() + assert plugin.name == "Peer Count Reporter" diff --git a/tox.ini b/tox.ini index 0cd44700d4..b29ded06d3 100644 --- a/tox.ini +++ b/tox.ini @@ -107,6 +107,7 @@ commands= mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs -p eth mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth.utils mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth.tools + mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth.vm [testenv:py36-lint] diff --git a/trinity/chains/light.py b/trinity/chains/light.py index 8f39634b99..ce06ba22d2 100644 --- a/trinity/chains/light.py +++ b/trinity/chains/light.py @@ -17,6 +17,7 @@ ) from eth_typing import ( + Address, BlockNumber, Hash32, ) @@ -184,9 +185,14 @@ def build_block_with_transactions( def create_transaction(self, *args: Any, **kwargs: Any) -> BaseTransaction: raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) - def create_unsigned_transaction(self, - *args: Any, - **kwargs: Any) -> BaseUnsignedTransaction: + def create_unsigned_transaction(cls, + *, + nonce: int, + gas_price: int, + gas: int, + to: Address, + value: int, + data: bytes) -> BaseUnsignedTransaction: raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) def get_canonical_transaction(self, transaction_hash: Hash32) -> BaseTransaction: diff --git a/trinity/cli_parser.py b/trinity/cli_parser.py index 09e222b1ab..3ced52fc6b 100644 --- a/trinity/cli_parser.py +++ b/trinity/cli_parser.py @@ -264,12 +264,7 @@ def __call__(self, '--nodekey', help=( "Hexadecimal encoded private key to use for the nodekey" - ) -) -chain_parser.add_argument( - '--nodekey-path', - help=( - "The filesystem path to the file which contains the nodekey" + " or the filesystem path to the file which contains the nodekey" ) ) diff --git a/trinity/extensibility/exceptions.py b/trinity/extensibility/exceptions.py index 20751ca257..036a93db9b 100644 --- a/trinity/extensibility/exceptions.py +++ b/trinity/extensibility/exceptions.py @@ -6,7 +6,15 @@ class EventBusNotReady(BaseTrinityError): """ Raised when a plugin tried to access an :class:`~lahja.eventbus.EventBus` before the plugin - had received its :meth:`~trinity.extensibility.plugin.BasePlugin.ready` call. + had received its :meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` call. + """ + pass + + +class InvalidPluginStatus(BaseTrinityError): + """ + Raised when it was attempted to perform an action while the current + :class:`~trinity.extensibility.plugin.PluginStatus` does not allow to perform such action. """ pass diff --git a/trinity/extensibility/plugin.py b/trinity/extensibility/plugin.py index b4698a697a..c693ca49f3 100644 --- a/trinity/extensibility/plugin.py +++ b/trinity/extensibility/plugin.py @@ -8,6 +8,10 @@ _SubParsersAction, ) import asyncio +from enum import ( + auto, + Enum, +) import logging from multiprocessing import ( Process @@ -38,6 +42,7 @@ ) from trinity.extensibility.exceptions import ( EventBusNotReady, + InvalidPluginStatus, ) from trinity.utils.ipc import ( kill_process_gracefully @@ -50,6 +55,16 @@ ) +class PluginStatus(Enum): + NOT_READY = auto() + READY = auto() + STARTED = auto() + STOPPED = auto() + + +INVALID_START_STATUS = (PluginStatus.NOT_READY, PluginStatus.STARTED,) + + class TrinityBootInfo(NamedTuple): args: Namespace trinity_config: TrinityConfig @@ -65,7 +80,7 @@ class PluginContext: The :class:`~trinity.extensibility.plugin.PluginContext` is set during startup and is guaranteed to exist by the time that a plugin receives its - :meth:`~trinity.extensibility.plugin.BasePlugin.ready` call. + :meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` call. """ def __init__(self, endpoint: Endpoint, boot_info: TrinityBootInfo) -> None: @@ -112,7 +127,7 @@ def trinity_config(self) -> TrinityConfig: class BasePlugin(ABC): context: PluginContext = None - running: bool = False + status: PluginStatus = PluginStatus.NOT_READY @property @abstractmethod @@ -140,6 +155,13 @@ def event_bus(self) -> Endpoint: return self.context.event_bus + @property + def running(self) -> bool: + """ + Return ``True`` if the ``status`` is ``PluginStatus.STARTED``, otherwise return ``False``. + """ + return self.status is PluginStatus.STARTED + def set_context(self, context: PluginContext) -> None: """ Set the :class:`~trinity.extensibility.plugin.PluginContext` for this plugin. @@ -147,6 +169,14 @@ def set_context(self, context: PluginContext) -> None: self.context = context def ready(self) -> None: + """ + Set the ``status`` to ``PluginStatus.READY`` and delegate to + :meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` + """ + self.status = PluginStatus.READY + self.on_ready() + + def on_ready(self) -> None: """ Notify the plugin that it is ready to bootstrap itself. Plugins can rely on the :class:`~trinity.extensibility.plugin.PluginContext` to be set @@ -157,24 +187,30 @@ def ready(self) -> None: def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAction) -> None: """ Give the plugin a chance to amend the Trinity CLI argument parser. This hook is called - before :meth:`~trinity.extensibility.plugin.BasePlugin.ready` + before :meth:`~trinity.extensibility.plugin.BasePlugin.on_ready` """ pass def start(self) -> None: """ - Delegate to :meth:`~trinity.extensibility.plugin.BasePlugin._start` and set ``running`` + Delegate to :meth:`~trinity.extensibility.plugin.BasePlugin.do_start` and set ``running`` to ``True``. Broadcast a :class:`~trinity.extensibility.events.PluginStartedEvent` on the :class:`~lahja.eventbus.EventBus` and hence allow other plugins to act accordingly. """ - self.running = True - self._start() + + if self.status in INVALID_START_STATUS: + raise InvalidPluginStatus( + f"Can not start plugin when the plugin status is {self.status}" + ) + + self.status = PluginStatus.STARTED + self.do_start() self.event_bus.broadcast( PluginStartedEvent(type(self)) ) self.logger.info("Plugin started: %s", self.name) - def _start(self) -> None: + def do_start(self) -> None: """ Perform the actual plugin start routine. In the case of a `BaseIsolatedPlugin` this method will be called in a separate process. @@ -190,7 +226,7 @@ class BaseSyncStopPlugin(BasePlugin): A :class:`~trinity.extensibility.plugin.BaseSyncStopPlugin` unwinds synchronoulsy, hence blocks until the shutdown is done. """ - def _stop(self) -> None: + def do_stop(self) -> None: """ Stop the plugin. Should be overwritten by subclasses. """ @@ -198,11 +234,11 @@ def _stop(self) -> None: def stop(self) -> None: """ - Delegate to :meth:`~trinity.extensibility.plugin.BaseSyncStopPlugin._stop` causing the + Delegate to :meth:`~trinity.extensibility.plugin.BaseSyncStopPlugin.do_stop` causing the plugin to stop and setting ``running`` to ``False``. """ - self._stop() - self.running = False + self.do_stop() + self.status = PluginStatus.STOPPED class BaseAsyncStopPlugin(BasePlugin): @@ -211,7 +247,7 @@ class BaseAsyncStopPlugin(BasePlugin): needs to be awaited. """ - async def _stop(self) -> None: + async def do_stop(self) -> None: """ Asynchronously stop the plugin. Should be overwritten by subclasses. """ @@ -219,11 +255,11 @@ async def _stop(self) -> None: async def stop(self) -> None: """ - Delegate to :meth:`~trinity.extensibility.plugin.BaseAsyncStopPlugin._stop` causing the + Delegate to :meth:`~trinity.extensibility.plugin.BaseAsyncStopPlugin.do_stop` causing the plugin to stop asynchronously and setting ``running`` to ``False``. """ - await self._stop() - self.running = False + await self.do_stop() + self.status = PluginStatus.STOPPED class BaseMainProcessPlugin(BasePlugin): @@ -250,9 +286,9 @@ class BaseIsolatedPlugin(BaseSyncStopPlugin): def start(self) -> None: """ - Prepare the plugin to get started and eventually call ``_start`` in a separate process. + Prepare the plugin to get started and eventually call ``do_start`` in a separate process. """ - self.running = True + self.status = PluginStatus.STARTED self._process = ctx.Process( target=self._prepare_start, ) @@ -268,9 +304,9 @@ def _prepare_start(self) -> None: self.event_bus.broadcast( PluginStartedEvent(type(self)) ) - self._start() + self.do_start() - def _stop(self) -> None: + def do_stop(self) -> None: self.context.event_bus.stop() kill_process_gracefully(self._process, self.logger) @@ -290,7 +326,7 @@ def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAct def handle_event(self, activation_event: BaseEvent) -> None: self.logger.info("Debug plugin: handle_event called: %s", activation_event) - def _start(self) -> None: + def do_start(self) -> None: self.logger.info("Debug plugin: start called") asyncio.ensure_future(self.count_forever()) @@ -301,5 +337,5 @@ async def count_forever(self) -> None: i += 1 await asyncio.sleep(1) - async def _stop(self) -> None: + async def do_stop(self) -> None: self.logger.info("Debug plugin: stop called") diff --git a/trinity/plugins/builtin/ethstats/plugin.py b/trinity/plugins/builtin/ethstats/plugin.py index a3344ee65f..b36ce35a9d 100644 --- a/trinity/plugins/builtin/ethstats/plugin.py +++ b/trinity/plugins/builtin/ethstats/plugin.py @@ -71,7 +71,7 @@ def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAct default=os.environ.get('ETHSTATS_NODE_CONTACT', ''), ) - def ready(self) -> None: + def on_ready(self) -> None: args = self.context.args if not args.ethstats: @@ -103,7 +103,7 @@ def ready(self) -> None: self.start() - def _start(self) -> None: + def do_start(self) -> None: service = EthstatsService( self.context, self.server_url, diff --git a/trinity/plugins/builtin/json_rpc/plugin.py b/trinity/plugins/builtin/json_rpc/plugin.py index e1e6bfd6e3..600927b57b 100644 --- a/trinity/plugins/builtin/json_rpc/plugin.py +++ b/trinity/plugins/builtin/json_rpc/plugin.py @@ -33,7 +33,7 @@ class JsonRpcServerPlugin(BaseIsolatedPlugin): def name(self) -> str: return "JSON-RPC Server" - def ready(self) -> None: + def on_ready(self) -> None: if not self.context.args.disable_rpc: self.start() @@ -44,7 +44,7 @@ def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAct help="Disables the JSON-RPC Server", ) - def _start(self) -> None: + def do_start(self) -> None: db_manager = create_db_manager(self.context.trinity_config.database_ipc_path) db_manager.connect() diff --git a/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py b/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py index ad0b900a2b..0ee3bc892b 100644 --- a/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py +++ b/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py @@ -40,7 +40,7 @@ class LightPeerChainBridgePlugin(BaseAsyncStopPlugin): def name(self) -> str: return "LightPeerChain Bridge" - def ready(self) -> None: + def on_ready(self) -> None: if self.context.trinity_config.sync_mode != SYNC_LIGHT: return @@ -57,12 +57,12 @@ def handle_event(self, event: ResourceAvailableEvent) -> None: self.chain = event.resource self.start() - def _start(self) -> None: + def do_start(self) -> None: chain = cast(LightDispatchChain, self.chain) self.handler = LightPeerChainEventBusHandler(chain._peer_chain, self.context.event_bus) asyncio.ensure_future(self.handler.run()) - async def _stop(self) -> None: + async def do_stop(self) -> None: # This isn't really needed for the standard shutdown case as the LightPeerChain will # automatically shutdown whenever the `CancelToken` it was chained with is triggered. # It may still be useful to stop the LightPeerChain Bridge plugin individually though. diff --git a/trinity/plugins/builtin/tx_pool/plugin.py b/trinity/plugins/builtin/tx_pool/plugin.py index f611eb6d3c..e3d984a1b5 100644 --- a/trinity/plugins/builtin/tx_pool/plugin.py +++ b/trinity/plugins/builtin/tx_pool/plugin.py @@ -54,7 +54,7 @@ def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAct help="Enables the Transaction Pool (experimental)", ) - def ready(self) -> None: + def on_ready(self) -> None: light_mode = self.context.args.sync_mode == SYNC_LIGHT self.is_enabled = self.context.args.tx_pool and not light_mode @@ -78,7 +78,7 @@ def handle_event(self, event: ResourceAvailableEvent) -> None: if all((self.peer_pool is not None, self.chain is not None, self.is_enabled)): self.start() - def _start(self) -> None: + def do_start(self) -> None: if isinstance(self.chain, BaseMainnetChain): validator = DefaultTransactionValidator(self.chain, BYZANTIUM_MAINNET_BLOCK) elif isinstance(self.chain, BaseRopstenChain): @@ -91,7 +91,7 @@ def _start(self) -> None: self.tx_pool = TxPool(self.peer_pool, validator, self.cancel_token) asyncio.ensure_future(self.tx_pool.run()) - async def _stop(self) -> None: + async def do_stop(self) -> None: # This isn't really needed for the standard shutdown case as the TxPool will automatically # shutdown whenever the `CancelToken` it was chained with is triggered. It may still be # useful to stop the TxPool plugin individually though. diff --git a/trinity/plugins/builtin/tx_pool/pool.py b/trinity/plugins/builtin/tx_pool/pool.py index de961061a0..bb0a532733 100644 --- a/trinity/plugins/builtin/tx_pool/pool.py +++ b/trinity/plugins/builtin/tx_pool/pool.py @@ -122,5 +122,5 @@ def _add_txs_to_bloom(self, peer: ETHPeer, txs: Iterable[BaseTransactionFields]) for val in txs: self._bloom.add(self._construct_bloom_entry(peer, val)) - async def _cleanup(self) -> None: + async def do_cleanup(self) -> None: self.logger.info("Stopping Tx Pool...") diff --git a/trinity/plugins/examples/__init__.py b/trinity/plugins/examples/__init__.py new file mode 100644 index 0000000000..5ffe92fb96 --- /dev/null +++ b/trinity/plugins/examples/__init__.py @@ -0,0 +1 @@ +from .peer_count_reporter.plugin import PeerCountReporterPlugin # noqa: F401 diff --git a/trinity/plugins/examples/peer_count_reporter/__init__.py b/trinity/plugins/examples/peer_count_reporter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trinity/plugins/examples/peer_count_reporter/plugin.py b/trinity/plugins/examples/peer_count_reporter/plugin.py new file mode 100644 index 0000000000..c27526f756 --- /dev/null +++ b/trinity/plugins/examples/peer_count_reporter/plugin.py @@ -0,0 +1,20 @@ +# This might end up as a temporary place for this. The following code is +# included in the documentation (literalinclude!) and uses a more concise +# form of imports. + +from argparse import ArgumentParser, _SubParsersAction + +from trinity.extensibility import BaseIsolatedPlugin + + +# --START CLASS-- +class PeerCountReporterPlugin(BaseIsolatedPlugin): + + @property + def name(self) -> str: + return "Peer Count Reporter" + + def configure_parser(self, + arg_parser: ArgumentParser, + subparser: _SubParsersAction) -> None: + arg_parser.add_argument("--report-peer-count", type=bool, required=False) diff --git a/trinity/utils/chains.py b/trinity/utils/chains.py index cb4788ffb8..5965303955 100644 --- a/trinity/utils/chains.py +++ b/trinity/utils/chains.py @@ -135,12 +135,11 @@ def construct_trinity_config_params( if args.data_dir is not None: yield 'data_dir', args.data_dir - if args.nodekey_path and args.nodekey: - raise ValueError("Cannot provide both nodekey_path and nodekey") - elif args.nodekey_path is not None: - yield 'nodekey_path', args.nodekey_path - elif args.nodekey is not None: - yield 'nodekey', decode_hex(args.nodekey) + if args.nodekey is not None: + if os.path.isfile(args.nodekey): + yield 'nodekey_path', args.nodekey + else: + yield 'nodekey', decode_hex(args.nodekey) if args.sync_mode is not None: yield 'sync_mode', args.sync_mode