Skip to content

Commit 599d03f

Browse files
committed
Make sure the fee and cost specified in a NewTransaction match the ones from validating its spend bundle.
1 parent 742a81a commit 599d03f

File tree

6 files changed

+134
-39
lines changed

6 files changed

+134
-39
lines changed

chia/_tests/core/full_node/test_full_node.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3362,18 +3362,25 @@ async def test_pending_tx_cache_retry_on_new_peak(
33623362
@pytest.mark.anyio
33633363
@pytest.mark.parametrize("mismatch_cost", [True, False])
33643364
@pytest.mark.parametrize("mismatch_fee", [True, False])
3365+
@pytest.mark.parametrize("tx_already_seen", [True, False])
3366+
@pytest.mark.parametrize("mismatch_on_reannounce", [True, False])
33653367
async def test_ban_for_mismatched_tx_cost_fee(
33663368
setup_two_nodes_fixture: tuple[list[FullNodeSimulator], list[tuple[WalletNode, ChiaServer]], BlockTools],
33673369
self_hostname: str,
33683370
mismatch_cost: bool,
33693371
mismatch_fee: bool,
3372+
tx_already_seen: bool,
3373+
mismatch_on_reannounce: bool,
33703374
) -> None:
33713375
"""
33723376
Tests that a peer gets banned if it sends a `NewTransaction` message with a
33733377
cost and/or fee that doesn't match the transaction's validation cost/fee.
3374-
We setup two full nodes with the test transaction as already seen, and we
3375-
check its validation cost and fee against the ones specified in the
3376-
`NewTransaction` message.
3378+
We setup two full nodes, and with `tx_already_seen` we control whether the
3379+
first full node has this transaction already or it needs to request it.
3380+
In both cases we check the transaction's validation cost and fee against
3381+
the ones specified in the `NewTransaction` message.
3382+
With `mismatch_on_reannounce` we control whether the peer sent us the same
3383+
transaction twice with different cost and fee.
33773384
"""
33783385
nodes, _, bt = setup_two_nodes_fixture
33793386
full_node_1, full_node_2 = nodes
@@ -3384,17 +3391,26 @@ async def test_ban_for_mismatched_tx_cost_fee(
33843391
ws_con_2 = next(iter(server_2.all_connections.values()))
33853392
wallet = WalletTool(test_constants)
33863393
wallet_ph = wallet.get_new_puzzlehash()
3394+
# If we're covering that the first full node has this transaction already
3395+
# we must add it accordingly, otherwise we'll add it to the second node so
3396+
# that the first node requests it, reacting to the NewTransaction message.
3397+
if tx_already_seen:
3398+
node = full_node_1.full_node
3399+
ws_con = ws_con_1
3400+
else:
3401+
node = full_node_2.full_node
3402+
ws_con = ws_con_2
33873403
blocks = bt.get_consecutive_blocks(
33883404
3, guarantee_transaction_block=True, farmer_reward_puzzle_hash=wallet_ph, pool_reward_puzzle_hash=wallet_ph
33893405
)
33903406
for block in blocks:
3391-
await full_node_1.full_node.add_block(block)
3407+
await node.add_block(block)
33923408
# Create a transaction and add it to the relevant full node's mempool
33933409
coin = blocks[-1].get_included_reward_coins()[0]
33943410
sb = wallet.generate_signed_transaction(uint64(42), wallet_ph, coin)
33953411
sb_name = sb.name()
3396-
await full_node_1.full_node.add_transaction(sb, sb_name, ws_con_1)
3397-
mempool_item = full_node_1.full_node.mempool_manager.get_mempool_item(sb_name)
3412+
await node.add_transaction(sb, sb_name, ws_con)
3413+
mempool_item = node.mempool_manager.get_mempool_item(sb_name)
33983414
assert mempool_item is not None
33993415
# Now send a NewTransaction with a cost and/or fee mismatch from the second
34003416
# full node.
@@ -3407,6 +3423,24 @@ async def test_ban_for_mismatched_tx_cost_fee(
34073423
ws_con_1.peer_info = PeerInfo(full_node_2_ip, ws_con_1.peer_info.port)
34083424
# Send the NewTransaction message from the second node to the first
34093425
await ws_con_2.send_message(msg)
3426+
if mismatch_on_reannounce and (mismatch_cost or mismatch_fee):
3427+
# Send a second NewTransaction that doesn't match the first
3428+
reannounce_cost = uint64(cost + 1) if mismatch_cost else cost
3429+
reannounce_fee = uint64(fee + 1) if mismatch_fee else fee
3430+
reannounce_msg = make_msg(
3431+
ProtocolMessageTypes.new_transaction, NewTransaction(mempool_item.name, reannounce_cost, reannounce_fee)
3432+
)
3433+
await ws_con_2.send_message(reannounce_msg)
3434+
# Make sure the peer is banned as it sent the same transaction twice
3435+
# with different cost and/or fee.
3436+
await time_out_assert(5, lambda: full_node_2_ip in server_1.banned_peers)
3437+
return
3438+
if not tx_already_seen:
3439+
# When the first full node receives the NewTransaction message and it
3440+
# hasn't seen the transaction before, it will issue a transaction
3441+
# request. We need to wait until it receives the transaction and add it
3442+
# to its mempool.
3443+
await time_out_assert(30, lambda: full_node_1.full_node.mempool_manager.seen(mempool_item.name))
34103444
# Make sure the first full node has banned the second as the item it has
34113445
# already seen has a different validation cost and/or fee than the one from
34123446
# the NewTransaction message.

chia/full_node/full_node.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from chia.full_node.mempool_manager import MempoolManager
6262
from chia.full_node.subscriptions import PeerSubscriptions, peers_for_spend_bundle
6363
from chia.full_node.sync_store import Peak, SyncStore
64-
from chia.full_node.tx_processing_queue import TransactionQueue, TransactionQueueEntry
64+
from chia.full_node.tx_processing_queue import PeerWithTx, TransactionQueue, TransactionQueueEntry
6565
from chia.full_node.weight_proof import WeightProofHandler
6666
from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol
6767
from chia.protocols.farmer_protocol import SignagePointSourceData, SPSubSlotSourceData, SPVDFSourceData
@@ -498,7 +498,9 @@ def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
498498
async def _handle_one_transaction(self, entry: TransactionQueueEntry) -> None:
499499
peer = entry.peer
500500
try:
501-
inc_status, err = await self.add_transaction(entry.transaction, entry.spend_name, peer, entry.test)
501+
inc_status, err = await self.add_transaction(
502+
entry.transaction, entry.spend_name, peer, entry.test, entry.peers_with_tx
503+
)
502504
entry.done.set((inc_status, err))
503505
except asyncio.CancelledError:
504506
error_stack = traceback.format_exc()
@@ -2761,7 +2763,14 @@ async def add_end_of_sub_slot(
27612763
return None, False
27622764

27632765
async def add_transaction(
2764-
self, transaction: SpendBundle, spend_name: bytes32, peer: Optional[WSChiaConnection] = None, test: bool = False
2766+
self,
2767+
transaction: SpendBundle,
2768+
spend_name: bytes32,
2769+
peer: Optional[WSChiaConnection] = None,
2770+
test: bool = False,
2771+
# Map of peer ID to its hostname, the fee and the cost it advertised
2772+
# for this transaction.
2773+
peers_with_tx: dict[bytes32, PeerWithTx] = {},
27652774
) -> tuple[MempoolInclusionStatus, Optional[Err]]:
27662775
if self.sync_store.get_sync_mode():
27672776
return MempoolInclusionStatus.FAILED, Err.NO_TRANSACTIONS_WHILE_SYNCING
@@ -2810,10 +2819,25 @@ async def add_transaction(
28102819
f"{self.mempool_manager.mempool.total_mempool_cost() / 5000000}"
28112820
)
28122821

2813-
# Only broadcast successful transactions, not pending ones. Otherwise it's a DOS
2814-
# vector.
28152822
mempool_item = self.mempool_manager.get_mempool_item(spend_name)
28162823
assert mempool_item is not None
2824+
# Now that we validated this transaction, check what fees and
2825+
# costs the peers have advertised for it.
2826+
for peer_id, entry in peers_with_tx.items():
2827+
if entry.advertised_fee == mempool_item.fee and entry.advertised_cost == mempool_item.cost:
2828+
continue
2829+
self.log.warning(
2830+
f"Banning peer {peer_id}. Sent us a new tx {spend_name} with mismatch "
2831+
f"on cost {entry.advertised_cost} vs validation cost {mempool_item.cost} and/or "
2832+
f"fee {entry.advertised_fee} vs {mempool_item.fee}."
2833+
)
2834+
peer = self.server.all_connections.get(peer_id)
2835+
if peer is None:
2836+
self.server.ban_peer(entry.peer_host, CONSENSUS_ERROR_BAN_SECONDS)
2837+
else:
2838+
await peer.close(CONSENSUS_ERROR_BAN_SECONDS)
2839+
# Only broadcast successful transactions, not pending ones. Otherwise it's a DOS
2840+
# vector.
28172841
await self.broadcast_removed_tx(info.removals)
28182842
await self.broadcast_added_tx(mempool_item, current_peer=peer)
28192843

chia/full_node/full_node_api.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import copy
45
import logging
56
import time
67
import traceback
@@ -44,13 +45,13 @@
4445
from chia.full_node.coin_store import CoinStore
4546
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
4647
from chia.full_node.full_block_utils import get_height_and_tx_status_from_block, header_block_from_block
47-
from chia.full_node.tx_processing_queue import TransactionQueueEntry, TransactionQueueFull
48+
from chia.full_node.tx_processing_queue import PeerWithTx, TransactionQueueEntry, TransactionQueueFull
4849
from chia.protocols import farmer_protocol, full_node_protocol, introducer_protocol, timelord_protocol, wallet_protocol
4950
from chia.protocols.fee_estimate import FeeEstimate, FeeEstimateGroup, fee_rate_v2_to_v1
5051
from chia.protocols.full_node_protocol import RejectBlock, RejectBlocks
5152
from chia.protocols.outbound_message import Message, make_msg
5253
from chia.protocols.protocol_message_types import ProtocolMessageTypes
53-
from chia.protocols.protocol_timing import RATE_LIMITER_BAN_SECONDS
54+
from chia.protocols.protocol_timing import CONSENSUS_ERROR_BAN_SECONDS, RATE_LIMITER_BAN_SECONDS
5455
from chia.protocols.shared_protocol import Capability
5556
from chia.protocols.wallet_protocol import (
5657
PuzzleSolutionResponse,
@@ -89,6 +90,10 @@ async def tx_request_and_timeout(full_node: FullNode, transaction_id: bytes32, t
8990
receive it or timeout.
9091
"""
9192
counter = 0
93+
# Make a copy as we'll pop from it here. We keep the original intact as we
94+
# need it in `respond_transaction` when constructing `TransactionQueueEntry`
95+
# to put into `transaction_queue`.
96+
peers_with_tx = copy.copy(full_node.full_node_store.peers_with_tx.get(transaction_id, {}))
9297
try:
9398
while True:
9499
# Limit to asking a few peers, it's possible that this tx got included on chain already
@@ -98,10 +103,9 @@ async def tx_request_and_timeout(full_node: FullNode, transaction_id: bytes32, t
98103
break
99104
if transaction_id not in full_node.full_node_store.peers_with_tx:
100105
break
101-
peers_with_tx: set[bytes32] = full_node.full_node_store.peers_with_tx[transaction_id]
102106
if len(peers_with_tx) == 0:
103107
break
104-
peer_id = peers_with_tx.pop()
108+
peer_id, _ = peers_with_tx.popitem()
105109
assert full_node.server is not None
106110
if peer_id not in full_node.server.all_connections:
107111
continue
@@ -221,28 +225,44 @@ async def new_transaction(
221225
f"with mismatch on cost {transaction.cost} vs validation cost {mempool_item.cost} and/or "
222226
f"fee {transaction.fees} vs {mempool_item.fee}."
223227
)
224-
await peer.close(RATE_LIMITER_BAN_SECONDS)
228+
await peer.close(CONSENSUS_ERROR_BAN_SECONDS)
225229
return None
226230

227231
if self.full_node.mempool_manager.is_fee_enough(transaction.fees, transaction.cost):
228232
# If there's current pending request just add this peer to the set of peers that have this tx
229233
if transaction.transaction_id in self.full_node.full_node_store.pending_tx_request:
230-
if transaction.transaction_id in self.full_node.full_node_store.peers_with_tx:
231-
current_set = self.full_node.full_node_store.peers_with_tx[transaction.transaction_id]
232-
if peer.peer_node_id in current_set:
233-
return None
234-
current_set.add(peer.peer_node_id)
234+
current_map = self.full_node.full_node_store.peers_with_tx.get(transaction.transaction_id)
235+
if current_map is None:
236+
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = {
237+
peer.peer_node_id: PeerWithTx(
238+
peer_host=peer.peer_info.host,
239+
advertised_fee=transaction.fees,
240+
advertised_cost=transaction.cost,
241+
)
242+
}
235243
return None
236-
else:
237-
new_set = set()
238-
new_set.add(peer.peer_node_id)
239-
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = new_set
244+
prev = current_map.get(peer.peer_node_id)
245+
if prev is not None:
246+
if prev.advertised_fee != transaction.fees or prev.advertised_cost != transaction.cost:
247+
self.log.warning(
248+
f"Banning peer {peer.peer_node_id}. Sent us a new tx {transaction.transaction_id} with "
249+
f"mismatch on cost {transaction.cost} vs previous advertised cost {prev.advertised_cost} "
250+
f"and/or fee {transaction.fees} vs previous advertised fee {prev.advertised_fee}."
251+
)
252+
await peer.close(CONSENSUS_ERROR_BAN_SECONDS)
240253
return None
254+
current_map[peer.peer_node_id] = PeerWithTx(
255+
peer_host=peer.peer_info.host, advertised_fee=transaction.fees, advertised_cost=transaction.cost
256+
)
257+
return None
241258

242259
self.full_node.full_node_store.pending_tx_request[transaction.transaction_id] = peer.peer_node_id
243-
new_set = set()
244-
new_set.add(peer.peer_node_id)
245-
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = new_set
260+
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = {
261+
peer.peer_node_id: PeerWithTx(
262+
peer_host=peer.peer_info.host, advertised_fee=transaction.fees, advertised_cost=transaction.cost
263+
)
264+
}
265+
246266
task_id: bytes32 = bytes32.secret()
247267
fetch_task = create_referenced_task(
248268
tx_request_and_timeout(self.full_node, transaction.transaction_id, task_id)
@@ -282,13 +302,15 @@ async def respond_transaction(
282302
spend_name = std_hash(tx_bytes)
283303
if spend_name in self.full_node.full_node_store.pending_tx_request:
284304
self.full_node.full_node_store.pending_tx_request.pop(spend_name)
305+
peers_with_tx = {}
285306
if spend_name in self.full_node.full_node_store.peers_with_tx:
286-
self.full_node.full_node_store.peers_with_tx.pop(spend_name)
307+
peers_with_tx = self.full_node.full_node_store.peers_with_tx.pop(spend_name)
287308

288309
# TODO: Use fee in priority calculation, to prioritize high fee TXs
289310
try:
290311
await self.full_node.transaction_queue.put(
291-
TransactionQueueEntry(tx.transaction, tx_bytes, spend_name, peer, test), peer.peer_node_id
312+
TransactionQueueEntry(tx.transaction, tx_bytes, spend_name, peer, test, peers_with_tx),
313+
peer.peer_node_id,
292314
)
293315
except TransactionQueueFull:
294316
pass # we can't do anything here, the tx will be dropped. We might do something in the future.

chia/full_node/full_node_store.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from chia.consensus.multiprocess_validation import PreValidationResult
1717
from chia.consensus.pot_iterations import calculate_sp_interval_iters
1818
from chia.consensus.signage_point import SignagePoint
19+
from chia.full_node.tx_processing_queue import PeerWithTx
1920
from chia.protocols import timelord_protocol
2021
from chia.protocols.outbound_message import Message
2122
from chia.types.blockchain_format.classgroup import ClassgroupElement
@@ -132,7 +133,9 @@ class FullNodeStore:
132133
recent_eos: LRUCache[bytes32, tuple[EndOfSubSlotBundle, float]]
133134

134135
pending_tx_request: dict[bytes32, bytes32] # tx_id: peer_id
135-
peers_with_tx: dict[bytes32, set[bytes32]] # tx_id: set[peer_ids}
136+
# Map of transaction ID to the map of peer ID to its hostname, fee and cost
137+
# it advertised for that transaction.
138+
peers_with_tx: dict[bytes32, dict[bytes32, PeerWithTx]]
136139
tx_fetch_tasks: dict[bytes32, asyncio.Task[None]] # Task id: task
137140
serialized_wp_message: Optional[Message]
138141
serialized_wp_message_tip: Optional[bytes32]

chia/full_node/tx_processing_queue.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from chia_rs import SpendBundle
1111
from chia_rs.sized_bytes import bytes32
12+
from chia_rs.sized_ints import uint64
1213

1314
from chia.server.ws_connection import WSChiaConnection
1415
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
@@ -45,6 +46,13 @@ async def wait(self) -> T:
4546
return self._value
4647

4748

49+
@dataclasses.dataclass(frozen=True)
50+
class PeerWithTx:
51+
peer_host: str
52+
advertised_fee: uint64
53+
advertised_cost: uint64
54+
55+
4856
@dataclass(frozen=True)
4957
class TransactionQueueEntry:
5058
"""
@@ -56,6 +64,9 @@ class TransactionQueueEntry:
5664
spend_name: bytes32
5765
peer: Optional[WSChiaConnection] = field(compare=False)
5866
test: bool = field(compare=False)
67+
# IDs of peers that advertised this transaction via new_transaction, along
68+
# with their hostname, fee and cost.
69+
peers_with_tx: dict[bytes32, PeerWithTx] = field(default_factory=dict, compare=False)
5970
done: ValuedEvent[tuple[MempoolInclusionStatus, Optional[Err]]] = field(
6071
default_factory=ValuedEvent,
6172
compare=False,

chia/server/server.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -550,14 +550,7 @@ async def connection_closed(
550550
)
551551
ban_time = 0
552552
if ban_time > 0:
553-
ban_until: float = time.time() + ban_time
554-
self.log.warning(f"Banning {connection.peer_info.host} for {ban_time} seconds")
555-
if connection.peer_info.host in self.banned_peers:
556-
self.banned_peers[connection.peer_info.host] = max(
557-
ban_until, self.banned_peers[connection.peer_info.host]
558-
)
559-
else:
560-
self.banned_peers[connection.peer_info.host] = ban_until
553+
self.ban_peer(connection.peer_info.host, ban_time)
561554

562555
present_connection = self.all_connections.get(connection.peer_node_id)
563556
if present_connection is connection:
@@ -737,3 +730,11 @@ def is_trusted_peer(self, peer: WSChiaConnection, trusted_peers: dict[str, Any])
737730

738731
def set_capabilities(self, capabilities: list[tuple[uint16, str]]) -> None:
739732
self._local_capabilities_for_handshake = capabilities
733+
734+
def ban_peer(self, host: str, ban_time: int) -> None:
735+
ban_until: float = time.time() + ban_time
736+
self.log.warning(f"Banning {host} for {ban_time} seconds")
737+
if host in self.banned_peers:
738+
self.banned_peers[host] = max(ban_until, self.banned_peers[host])
739+
else:
740+
self.banned_peers[host] = ban_until

0 commit comments

Comments
 (0)