Skip to content

Commit f28f786

Browse files
authored
refactor(plugins/execute): Refactor account refunds logic (#1204)
* refactor(execute): Move refund logic to `pre` fixture * fix(plugins/execute): Overwrite tx instead of bumping nonce * fix(rpc): handle null transaction response * fix(execute): Return `--sender-funding-txs-gas-price`
1 parent e67d774 commit f28f786

File tree

6 files changed

+72
-53
lines changed

6 files changed

+72
-53
lines changed

src/cli/gentest/request_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(self):
3636
def eth_get_transaction_by_hash(self, transaction_hash: Hash) -> TransactionByHashResponse:
3737
"""Get transaction data."""
3838
res = self.rpc.get_transaction_by_hash(transaction_hash)
39+
assert res is not None, "Transaction not found"
3940
block_number = res.block_number
4041
assert block_number is not None, "Transaction does not seem to be included in any block"
4142

src/ethereum_test_rpc/rpc.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,13 @@ def get_transaction_count(
136136
block = hex(block_number) if isinstance(block_number, int) else block_number
137137
return int(self.post_request("getTransactionCount", f"{address}", block), 16)
138138

139-
def get_transaction_by_hash(self, transaction_hash: Hash) -> TransactionByHashResponse:
139+
def get_transaction_by_hash(self, transaction_hash: Hash) -> TransactionByHashResponse | None:
140140
"""`eth_getTransactionByHash`: Returns transaction details."""
141141
try:
142-
resp = TransactionByHashResponse(
143-
**self.post_request("getTransactionByHash", f"{transaction_hash}")
144-
)
145-
return resp
142+
response = self.post_request("getTransactionByHash", f"{transaction_hash}")
143+
if response is None:
144+
return None
145+
return TransactionByHashResponse(**response)
146146
except ValidationError as e:
147147
pprint(e.errors())
148148
raise e
@@ -200,7 +200,7 @@ def wait_for_transaction(self, transaction: Transaction) -> TransactionByHashRes
200200
start_time = time.time()
201201
while True:
202202
tx = self.get_transaction_by_hash(tx_hash)
203-
if tx.block_number is not None:
203+
if tx is not None and tx.block_number is not None:
204204
return tx
205205
if (time.time() - start_time) > self.transaction_wait_timeout:
206206
break
@@ -225,7 +225,7 @@ def wait_for_transactions(
225225
while i < len(tx_hashes):
226226
tx_hash = tx_hashes[i]
227227
tx = self.get_transaction_by_hash(tx_hash)
228-
if tx.block_number is not None:
228+
if tx is not None and tx.block_number is not None:
229229
responses.append(tx)
230230
tx_hashes.pop(i)
231231
else:

src/pytest_plugins/execute/execute.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
import pytest
88
from pytest_metadata.plugin import metadata_key # type: ignore
99

10-
from ethereum_test_base_types import Number
1110
from ethereum_test_execution import BaseExecute
1211
from ethereum_test_forks import Fork
1312
from ethereum_test_rpc import EthRPC
14-
from ethereum_test_tools import SPEC_TYPES, BaseTest, TestInfo, Transaction
13+
from ethereum_test_tools import SPEC_TYPES, BaseTest, TestInfo
1514
from ethereum_test_types import TransactionDefaults
1615
from pytest_plugins.spec_version_checker.spec_version_checker import EIPSpecTestItem
1716

@@ -173,7 +172,9 @@ def pytest_html_report_title(report):
173172
@pytest.fixture(scope="session")
174173
def default_gas_price(request) -> int:
175174
"""Return default gas price used for transactions."""
176-
return request.config.getoption("default_gas_price")
175+
gas_price = request.config.getoption("default_gas_price")
176+
assert gas_price > 0, "Gas price must be greater than 0"
177+
return gas_price
177178

178179

179180
@pytest.fixture(scope="session")
@@ -252,7 +253,6 @@ def base_test_parametrizer_func(
252253
eips: List[int],
253254
eth_rpc: EthRPC,
254255
collector: Collector,
255-
default_gas_price: int,
256256
):
257257
"""
258258
Fixture used to instantiate an auto-fillable BaseTest object from within
@@ -281,13 +281,14 @@ def __init__(self, *args, **kwargs):
281281

282282
# wait for pre-requisite transactions to be included in blocks
283283
pre.wait_for_transactions()
284-
for deployed_contract, deployed_code in pre._deployed_contracts:
285-
if eth_rpc.get_code(deployed_contract) == deployed_code:
286-
pass
287-
else:
284+
for deployed_contract, expected_code in pre._deployed_contracts:
285+
actual_code = eth_rpc.get_code(deployed_contract)
286+
if actual_code != expected_code:
288287
raise Exception(
289288
f"Deployed test contract didn't match expected code at address "
290-
f"{deployed_contract} (not enough gas_limit?)."
289+
f"{deployed_contract} (not enough gas_limit?).\n"
290+
f"Expected: {expected_code}\n"
291+
f"Actual: {actual_code}"
291292
)
292293
request.node.config.funded_accounts = ", ".join(
293294
[str(eoa) for eoa in pre._funded_eoa]
@@ -297,33 +298,7 @@ def __init__(self, *args, **kwargs):
297298
execute.execute(eth_rpc)
298299
collector.collect(request.node.nodeid, execute)
299300

300-
sender_start_balance = eth_rpc.get_balance(pre._sender)
301-
302-
yield BaseTestWrapper
303-
304-
# Refund all EOAs (regardless of whether the test passed or failed)
305-
refund_txs = []
306-
for eoa in pre._funded_eoa:
307-
remaining_balance = eth_rpc.get_balance(eoa)
308-
eoa.nonce = Number(eth_rpc.get_transaction_count(eoa))
309-
refund_gas_limit = 21_000
310-
tx_cost = refund_gas_limit * default_gas_price
311-
if remaining_balance < tx_cost:
312-
continue
313-
refund_txs.append(
314-
Transaction(
315-
sender=eoa,
316-
to=pre._sender,
317-
gas_limit=21_000,
318-
gas_price=default_gas_price,
319-
value=remaining_balance - tx_cost,
320-
).with_signature_and_sender()
321-
)
322-
eth_rpc.send_wait_transactions(refund_txs)
323-
324-
sender_end_balance = eth_rpc.get_balance(pre._sender)
325-
used_balance = sender_start_balance - sender_end_balance
326-
print(f"Used balance={used_balance / 10**18:.18f}")
301+
return BaseTestWrapper
327302

328303
return base_test_parametrizer_func
329304

src/pytest_plugins/execute/pre_alloc.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from itertools import count
44
from random import randint
5-
from typing import Iterator, List, Literal, Tuple
5+
from typing import Generator, Iterator, List, Literal, Tuple
66

77
import pytest
88
from pydantic import PrivateAttr
99

10-
from ethereum_test_base_types import Number, StorageRootType, ZeroPaddedHexNumber
10+
from ethereum_test_base_types import Bytes, Number, StorageRootType, ZeroPaddedHexNumber
1111
from ethereum_test_base_types.conversions import (
1212
BytesConvertible,
1313
FixedSizeBytesConvertible,
@@ -94,7 +94,7 @@ class Alloc(BaseAlloc):
9494
_sender: EOA = PrivateAttr()
9595
_eth_rpc: EthRPC = PrivateAttr()
9696
_txs: List[Transaction] = PrivateAttr(default_factory=list)
97-
_deployed_contracts: List[Tuple[Address, bytes]] = PrivateAttr(default_factory=list)
97+
_deployed_contracts: List[Tuple[Address, Bytes]] = PrivateAttr(default_factory=list)
9898
_funded_eoa: List[EOA] = PrivateAttr(default_factory=list)
9999
_evm_code_type: EVMCodeType | None = PrivateAttr(None)
100100
_chain_id: int = PrivateAttr()
@@ -206,7 +206,7 @@ def deploy_contract(
206206
self._txs.append(deploy_tx)
207207

208208
contract_address = deploy_tx.created_contract
209-
self._deployed_contracts.append((contract_address, bytes(code)))
209+
self._deployed_contracts.append((contract_address, Bytes(code)))
210210

211211
assert Number(nonce) >= 1, "impossible to deploy contract with nonce lower than one"
212212

@@ -373,9 +373,14 @@ def pre(
373373
evm_code_type: EVMCodeType,
374374
chain_id: int,
375375
eoa_fund_amount_default: int,
376-
) -> Alloc:
376+
default_gas_price: int,
377+
) -> Generator[Alloc, None, None]:
377378
"""Return default pre allocation for all tests (Empty alloc)."""
378-
return Alloc(
379+
# Record the starting balance of the sender
380+
sender_test_starting_balance = eth_rpc.get_balance(sender_key)
381+
382+
# Prepare the pre-alloc
383+
pre = Alloc(
379384
fork=fork,
380385
sender=sender_key,
381386
eth_rpc=eth_rpc,
@@ -384,3 +389,31 @@ def pre(
384389
chain_id=chain_id,
385390
eoa_fund_amount_default=eoa_fund_amount_default,
386391
)
392+
393+
# Yield the pre-alloc for usage during the test
394+
yield pre
395+
396+
# Refund all EOAs (regardless of whether the test passed or failed)
397+
refund_txs = []
398+
for eoa in pre._funded_eoa:
399+
remaining_balance = eth_rpc.get_balance(eoa)
400+
eoa.nonce = Number(eth_rpc.get_transaction_count(eoa))
401+
refund_gas_limit = 21_000
402+
tx_cost = refund_gas_limit * default_gas_price
403+
if remaining_balance < tx_cost:
404+
continue
405+
refund_txs.append(
406+
Transaction(
407+
sender=eoa,
408+
to=sender_key,
409+
gas_limit=21_000,
410+
gas_price=default_gas_price,
411+
value=remaining_balance - tx_cost,
412+
).with_signature_and_sender()
413+
)
414+
eth_rpc.send_wait_transactions(refund_txs)
415+
416+
# Record the ending balance of the sender
417+
sender_test_ending_balance = eth_rpc.get_balance(sender_key)
418+
used_balance = sender_test_starting_balance - sender_test_ending_balance
419+
print(f"Used balance={used_balance / 10**18:.18f}")

src/pytest_plugins/execute/rpc/hive.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,7 @@ def wait_for_transactions(
721721
while tx_id < len(tx_hashes):
722722
tx_hash = tx_hashes[tx_id]
723723
tx = self.get_transaction_by_hash(tx_hash)
724+
assert tx is not None, f"Transaction {tx_hash} not found"
724725
if tx.block_number is not None:
725726
responses.append(tx)
726727
tx_hashes.pop(tx_id)

src/pytest_plugins/execute/sender.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def pytest_addoption(parser):
3434
action="store",
3535
dest="sender_funding_transactions_gas_price",
3636
type=Wei,
37-
default=10**9,
37+
default=None,
3838
help=("Gas price set for the funding transactions of each worker's sender key."),
3939
)
4040

@@ -49,9 +49,15 @@ def pytest_addoption(parser):
4949

5050

5151
@pytest.fixture(scope="session")
52-
def sender_funding_transactions_gas_price(request: pytest.FixtureRequest) -> int:
52+
def sender_funding_transactions_gas_price(
53+
request: pytest.FixtureRequest, default_gas_price: int
54+
) -> int:
5355
"""Get the gas price for the funding transactions."""
54-
return request.config.option.sender_funding_transactions_gas_price
56+
gas_price: int | None = request.config.option.sender_funding_transactions_gas_price
57+
if gas_price is None:
58+
gas_price = default_gas_price
59+
assert gas_price > 0, "Gas price must be greater than 0"
60+
return gas_price
5561

5662

5763
@pytest.fixture(scope="session")
@@ -159,13 +165,16 @@ def sender_key(
159165

160166
# refund seed sender
161167
remaining_balance = eth_rpc.get_balance(sender)
168+
sender.nonce = Number(eth_rpc.get_transaction_count(sender))
162169
used_balance = sender_key_initial_balance - remaining_balance
163170
request.config.stash[metadata_key]["Senders"][str(sender)] = (
164171
f"Used balance={used_balance / 10**18:.18f}"
165172
)
166173

167174
refund_gas_limit = sender_fund_refund_gas_limit
168-
refund_gas_price = sender_funding_transactions_gas_price
175+
# double the gas price to ensure the transaction is included and overwrites any other
176+
# transaction that might have been sent by the sender.
177+
refund_gas_price = sender_funding_transactions_gas_price * 2
169178
tx_cost = refund_gas_limit * refund_gas_price
170179

171180
if (remaining_balance - 1) < tx_cost:

0 commit comments

Comments
 (0)