diff --git a/chia/_tests/core/full_node/test_full_node.py b/chia/_tests/core/full_node/test_full_node.py index 1f800ebbd2c1..b7a007b83c0f 100644 --- a/chia/_tests/core/full_node/test_full_node.py +++ b/chia/_tests/core/full_node/test_full_node.py @@ -3362,39 +3362,58 @@ async def test_pending_tx_cache_retry_on_new_peak( @pytest.mark.anyio @pytest.mark.parametrize("mismatch_cost", [True, False]) @pytest.mark.parametrize("mismatch_fee", [True, False]) +@pytest.mark.parametrize("tx_already_seen", [True, False]) +@pytest.mark.parametrize("mismatch_on_reannounce", [True, False]) async def test_ban_for_mismatched_tx_cost_fee( - setup_two_nodes_fixture: tuple[list[FullNodeSimulator], list[tuple[WalletNode, ChiaServer]], BlockTools], + three_nodes: list[FullNodeAPI], + bt: BlockTools, self_hostname: str, mismatch_cost: bool, mismatch_fee: bool, + tx_already_seen: bool, + mismatch_on_reannounce: bool, ) -> None: """ Tests that a peer gets banned if it sends a `NewTransaction` message with a cost and/or fee that doesn't match the transaction's validation cost/fee. - We setup two full nodes with the test transaction as already seen, and we - check its validation cost and fee against the ones specified in the - `NewTransaction` message. + We setup full nodes, and with `tx_already_seen` we control whether the + first full node has this transaction already or it needs to request it. + In both cases we check the transaction's validation cost and fee against + the ones specified in the `NewTransaction` message. + With `mismatch_on_reannounce` we control whether the peer sent us the same + transaction twice with different cost and fee. """ - nodes, _, bt = setup_two_nodes_fixture - full_node_1, full_node_2 = nodes + full_node_1, full_node_2, full_node_3 = three_nodes server_1 = full_node_1.full_node.server server_2 = full_node_2.full_node.server + server_3 = full_node_3.full_node.server await server_2.start_client(PeerInfo(self_hostname, server_1.get_port()), full_node_2.full_node.on_connect) + await server_3.start_client(PeerInfo(self_hostname, server_1.get_port()), full_node_3.full_node.on_connect) ws_con_1 = next(iter(server_1.all_connections.values())) ws_con_2 = next(iter(server_2.all_connections.values())) + ws_con_3 = next(iter(server_3.all_connections.values())) wallet = WalletTool(test_constants) wallet_ph = wallet.get_new_puzzlehash() + # If we're covering that the first full node has this transaction already + # we must add it accordingly, otherwise we'll add it to the second node so + # that the first node requests it, reacting to the NewTransaction message. + if tx_already_seen: + node = full_node_1.full_node + ws_con = ws_con_1 + else: + node = full_node_2.full_node + ws_con = ws_con_2 blocks = bt.get_consecutive_blocks( 3, guarantee_transaction_block=True, farmer_reward_puzzle_hash=wallet_ph, pool_reward_puzzle_hash=wallet_ph ) for block in blocks: - await full_node_1.full_node.add_block(block) + await node.add_block(block) # Create a transaction and add it to the relevant full node's mempool coin = blocks[-1].get_included_reward_coins()[0] sb = wallet.generate_signed_transaction(uint64(42), wallet_ph, coin) sb_name = sb.name() - await full_node_1.full_node.add_transaction(sb, sb_name, ws_con_1) - mempool_item = full_node_1.full_node.mempool_manager.get_mempool_item(sb_name) + await node.add_transaction(sb, sb_name, ws_con) + mempool_item = node.mempool_manager.get_mempool_item(sb_name) assert mempool_item is not None # Now send a NewTransaction with a cost and/or fee mismatch from the second # full node. @@ -3405,8 +3424,35 @@ async def test_ban_for_mismatched_tx_cost_fee( # second node. full_node_2_ip = "1.3.3.7" ws_con_1.peer_info = PeerInfo(full_node_2_ip, ws_con_1.peer_info.port) + # Send the NewTransaction message from the second node to the first - await ws_con_2.send_message(msg) + async def send_from_node_2() -> None: + await ws_con_2.send_message(msg) + + # Send this message from the third node as well, just to end up with two + # peers advertising the same transaction at the same time. + async def send_from_node_3() -> None: + await ws_con_3.send_message(msg) + + await asyncio.gather(send_from_node_2(), send_from_node_3()) + if mismatch_on_reannounce and (mismatch_cost or mismatch_fee): + # Send a second NewTransaction that doesn't match the first + reannounce_cost = uint64(cost + 1) if mismatch_cost else cost + reannounce_fee = uint64(fee + 1) if mismatch_fee else fee + reannounce_msg = make_msg( + ProtocolMessageTypes.new_transaction, NewTransaction(mempool_item.name, reannounce_cost, reannounce_fee) + ) + await ws_con_2.send_message(reannounce_msg) + # Make sure the peer is banned as it sent the same transaction twice + # with different cost and/or fee. + await time_out_assert(5, lambda: full_node_2_ip in server_1.banned_peers) + return + if not tx_already_seen: + # When the first full node receives the NewTransaction message and it + # hasn't seen the transaction before, it will issue a transaction + # request. We need to wait until it receives the transaction and add it + # to its mempool. + await time_out_assert(30, lambda: full_node_1.full_node.mempool_manager.seen(mempool_item.name)) # Make sure the first full node has banned the second as the item it has # already seen has a different validation cost and/or fee than the one from # the NewTransaction message. diff --git a/chia/full_node/full_node.py b/chia/full_node/full_node.py index f935b6f95c51..756d3b1a4933 100644 --- a/chia/full_node/full_node.py +++ b/chia/full_node/full_node.py @@ -61,7 +61,7 @@ from chia.full_node.mempool_manager import MempoolManager from chia.full_node.subscriptions import PeerSubscriptions, peers_for_spend_bundle from chia.full_node.sync_store import Peak, SyncStore -from chia.full_node.tx_processing_queue import TransactionQueue, TransactionQueueEntry +from chia.full_node.tx_processing_queue import PeerWithTx, TransactionQueue, TransactionQueueEntry from chia.full_node.weight_proof import WeightProofHandler from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol from chia.protocols.farmer_protocol import SignagePointSourceData, SPSubSlotSourceData, SPVDFSourceData @@ -498,7 +498,9 @@ def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None: async def _handle_one_transaction(self, entry: TransactionQueueEntry) -> None: peer = entry.peer try: - inc_status, err = await self.add_transaction(entry.transaction, entry.spend_name, peer, entry.test) + inc_status, err = await self.add_transaction( + entry.transaction, entry.spend_name, peer, entry.test, entry.peers_with_tx + ) entry.done.set((inc_status, err)) except asyncio.CancelledError: error_stack = traceback.format_exc() @@ -2761,7 +2763,14 @@ async def add_end_of_sub_slot( return None, False async def add_transaction( - self, transaction: SpendBundle, spend_name: bytes32, peer: Optional[WSChiaConnection] = None, test: bool = False + self, + transaction: SpendBundle, + spend_name: bytes32, + peer: Optional[WSChiaConnection] = None, + test: bool = False, + # Map of peer ID to its hostname, the fee and the cost it advertised + # for this transaction. + peers_with_tx: dict[bytes32, PeerWithTx] = {}, ) -> tuple[MempoolInclusionStatus, Optional[Err]]: if self.sync_store.get_sync_mode(): return MempoolInclusionStatus.FAILED, Err.NO_TRANSACTIONS_WHILE_SYNCING @@ -2810,10 +2819,25 @@ async def add_transaction( f"{self.mempool_manager.mempool.total_mempool_cost() / 5000000}" ) - # Only broadcast successful transactions, not pending ones. Otherwise it's a DOS - # vector. mempool_item = self.mempool_manager.get_mempool_item(spend_name) assert mempool_item is not None + # Now that we validated this transaction, check what fees and + # costs the peers have advertised for it. + for peer_id, entry in peers_with_tx.items(): + if entry.advertised_fee == mempool_item.fee and entry.advertised_cost == mempool_item.cost: + continue + self.log.warning( + f"Banning peer {peer_id}. Sent us a new tx {spend_name} with mismatch " + f"on cost {entry.advertised_cost} vs validation cost {mempool_item.cost} and/or " + f"fee {entry.advertised_fee} vs {mempool_item.fee}." + ) + peer = self.server.all_connections.get(peer_id) + if peer is None: + self.server.ban_peer(entry.peer_host, CONSENSUS_ERROR_BAN_SECONDS) + else: + await peer.close(CONSENSUS_ERROR_BAN_SECONDS) + # Only broadcast successful transactions, not pending ones. Otherwise it's a DOS + # vector. await self.broadcast_removed_tx(info.removals) await self.broadcast_added_tx(mempool_item, current_peer=peer) diff --git a/chia/full_node/full_node_api.py b/chia/full_node/full_node_api.py index ed88d2d32c0a..ae760f4ea333 100644 --- a/chia/full_node/full_node_api.py +++ b/chia/full_node/full_node_api.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import copy import logging import time import traceback @@ -44,13 +45,13 @@ from chia.full_node.coin_store import CoinStore from chia.full_node.fee_estimator_interface import FeeEstimatorInterface from chia.full_node.full_block_utils import get_height_and_tx_status_from_block, header_block_from_block -from chia.full_node.tx_processing_queue import TransactionQueueEntry, TransactionQueueFull +from chia.full_node.tx_processing_queue import PeerWithTx, TransactionQueueEntry, TransactionQueueFull from chia.protocols import farmer_protocol, full_node_protocol, introducer_protocol, timelord_protocol, wallet_protocol from chia.protocols.fee_estimate import FeeEstimate, FeeEstimateGroup, fee_rate_v2_to_v1 from chia.protocols.full_node_protocol import RejectBlock, RejectBlocks from chia.protocols.outbound_message import Message, make_msg from chia.protocols.protocol_message_types import ProtocolMessageTypes -from chia.protocols.protocol_timing import RATE_LIMITER_BAN_SECONDS +from chia.protocols.protocol_timing import CONSENSUS_ERROR_BAN_SECONDS, RATE_LIMITER_BAN_SECONDS from chia.protocols.shared_protocol import Capability from chia.protocols.wallet_protocol import ( PuzzleSolutionResponse, @@ -89,6 +90,10 @@ async def tx_request_and_timeout(full_node: FullNode, transaction_id: bytes32, t receive it or timeout. """ counter = 0 + # Make a copy as we'll pop from it here. We keep the original intact as we + # need it in `respond_transaction` when constructing `TransactionQueueEntry` + # to put into `transaction_queue`. + peers_with_tx = copy.copy(full_node.full_node_store.peers_with_tx.get(transaction_id, {})) try: while True: # Limit to asking a few peers, it's possible that this tx got included on chain already @@ -98,10 +103,9 @@ async def tx_request_and_timeout(full_node: FullNode, transaction_id: bytes32, t break if transaction_id not in full_node.full_node_store.peers_with_tx: break - peers_with_tx: set[bytes32] = full_node.full_node_store.peers_with_tx[transaction_id] if len(peers_with_tx) == 0: break - peer_id = peers_with_tx.pop() + peer_id, _ = peers_with_tx.popitem() assert full_node.server is not None if peer_id not in full_node.server.all_connections: continue @@ -221,28 +225,44 @@ async def new_transaction( f"with mismatch on cost {transaction.cost} vs validation cost {mempool_item.cost} and/or " f"fee {transaction.fees} vs {mempool_item.fee}." ) - await peer.close(RATE_LIMITER_BAN_SECONDS) + await peer.close(CONSENSUS_ERROR_BAN_SECONDS) return None if self.full_node.mempool_manager.is_fee_enough(transaction.fees, transaction.cost): # If there's current pending request just add this peer to the set of peers that have this tx if transaction.transaction_id in self.full_node.full_node_store.pending_tx_request: - if transaction.transaction_id in self.full_node.full_node_store.peers_with_tx: - current_set = self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] - if peer.peer_node_id in current_set: - return None - current_set.add(peer.peer_node_id) + current_map = self.full_node.full_node_store.peers_with_tx.get(transaction.transaction_id) + if current_map is None: + self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = { + peer.peer_node_id: PeerWithTx( + peer_host=peer.peer_info.host, + advertised_fee=transaction.fees, + advertised_cost=transaction.cost, + ) + } return None - else: - new_set = set() - new_set.add(peer.peer_node_id) - self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = new_set + prev = current_map.get(peer.peer_node_id) + if prev is not None: + if prev.advertised_fee != transaction.fees or prev.advertised_cost != transaction.cost: + self.log.warning( + f"Banning peer {peer.peer_node_id}. Sent us a new tx {transaction.transaction_id} with " + f"mismatch on cost {transaction.cost} vs previous advertised cost {prev.advertised_cost} " + f"and/or fee {transaction.fees} vs previous advertised fee {prev.advertised_fee}." + ) + await peer.close(CONSENSUS_ERROR_BAN_SECONDS) return None + current_map[peer.peer_node_id] = PeerWithTx( + peer_host=peer.peer_info.host, advertised_fee=transaction.fees, advertised_cost=transaction.cost + ) + return None self.full_node.full_node_store.pending_tx_request[transaction.transaction_id] = peer.peer_node_id - new_set = set() - new_set.add(peer.peer_node_id) - self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = new_set + self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = { + peer.peer_node_id: PeerWithTx( + peer_host=peer.peer_info.host, advertised_fee=transaction.fees, advertised_cost=transaction.cost + ) + } + task_id: bytes32 = bytes32.secret() fetch_task = create_referenced_task( tx_request_and_timeout(self.full_node, transaction.transaction_id, task_id) @@ -282,13 +302,15 @@ async def respond_transaction( spend_name = std_hash(tx_bytes) if spend_name in self.full_node.full_node_store.pending_tx_request: self.full_node.full_node_store.pending_tx_request.pop(spend_name) + peers_with_tx = {} if spend_name in self.full_node.full_node_store.peers_with_tx: - self.full_node.full_node_store.peers_with_tx.pop(spend_name) + peers_with_tx = self.full_node.full_node_store.peers_with_tx.pop(spend_name) # TODO: Use fee in priority calculation, to prioritize high fee TXs try: await self.full_node.transaction_queue.put( - TransactionQueueEntry(tx.transaction, tx_bytes, spend_name, peer, test), peer.peer_node_id + TransactionQueueEntry(tx.transaction, tx_bytes, spend_name, peer, test, peers_with_tx), + peer.peer_node_id, ) except TransactionQueueFull: pass # we can't do anything here, the tx will be dropped. We might do something in the future. diff --git a/chia/full_node/full_node_store.py b/chia/full_node/full_node_store.py index b21abc6cd36e..3cb64bba5703 100644 --- a/chia/full_node/full_node_store.py +++ b/chia/full_node/full_node_store.py @@ -16,6 +16,7 @@ from chia.consensus.multiprocess_validation import PreValidationResult from chia.consensus.pot_iterations import calculate_sp_interval_iters from chia.consensus.signage_point import SignagePoint +from chia.full_node.tx_processing_queue import PeerWithTx from chia.protocols import timelord_protocol from chia.protocols.outbound_message import Message from chia.types.blockchain_format.classgroup import ClassgroupElement @@ -132,7 +133,9 @@ class FullNodeStore: recent_eos: LRUCache[bytes32, tuple[EndOfSubSlotBundle, float]] pending_tx_request: dict[bytes32, bytes32] # tx_id: peer_id - peers_with_tx: dict[bytes32, set[bytes32]] # tx_id: set[peer_ids} + # Map of transaction ID to the map of peer ID to its hostname, fee and cost + # it advertised for that transaction. + peers_with_tx: dict[bytes32, dict[bytes32, PeerWithTx]] tx_fetch_tasks: dict[bytes32, asyncio.Task[None]] # Task id: task serialized_wp_message: Optional[Message] serialized_wp_message_tip: Optional[bytes32] diff --git a/chia/full_node/tx_processing_queue.py b/chia/full_node/tx_processing_queue.py index cbf17a3d452c..6b262ee7e244 100644 --- a/chia/full_node/tx_processing_queue.py +++ b/chia/full_node/tx_processing_queue.py @@ -9,6 +9,7 @@ from chia_rs import SpendBundle from chia_rs.sized_bytes import bytes32 +from chia_rs.sized_ints import uint64 from chia.server.ws_connection import WSChiaConnection from chia.types.mempool_inclusion_status import MempoolInclusionStatus @@ -45,6 +46,13 @@ async def wait(self) -> T: return self._value +@dataclasses.dataclass(frozen=True) +class PeerWithTx: + peer_host: str + advertised_fee: uint64 + advertised_cost: uint64 + + @dataclass(frozen=True) class TransactionQueueEntry: """ @@ -56,6 +64,9 @@ class TransactionQueueEntry: spend_name: bytes32 peer: Optional[WSChiaConnection] = field(compare=False) test: bool = field(compare=False) + # IDs of peers that advertised this transaction via new_transaction, along + # with their hostname, fee and cost. + peers_with_tx: dict[bytes32, PeerWithTx] = field(default_factory=dict, compare=False) done: ValuedEvent[tuple[MempoolInclusionStatus, Optional[Err]]] = field( default_factory=ValuedEvent, compare=False, diff --git a/chia/server/server.py b/chia/server/server.py index e93d77fe555c..4cef6fa4a3e3 100644 --- a/chia/server/server.py +++ b/chia/server/server.py @@ -550,14 +550,7 @@ async def connection_closed( ) ban_time = 0 if ban_time > 0: - ban_until: float = time.time() + ban_time - self.log.warning(f"Banning {connection.peer_info.host} for {ban_time} seconds") - if connection.peer_info.host in self.banned_peers: - self.banned_peers[connection.peer_info.host] = max( - ban_until, self.banned_peers[connection.peer_info.host] - ) - else: - self.banned_peers[connection.peer_info.host] = ban_until + self.ban_peer(connection.peer_info.host, ban_time) present_connection = self.all_connections.get(connection.peer_node_id) if present_connection is connection: @@ -737,3 +730,11 @@ def is_trusted_peer(self, peer: WSChiaConnection, trusted_peers: dict[str, Any]) def set_capabilities(self, capabilities: list[tuple[uint16, str]]) -> None: self._local_capabilities_for_handshake = capabilities + + def ban_peer(self, host: str, ban_time: int) -> None: + ban_until: float = time.time() + ban_time + self.log.warning(f"Banning {host} for {ban_time} seconds") + if host in self.banned_peers: + self.banned_peers[host] = max(ban_until, self.banned_peers[host]) + else: + self.banned_peers[host] = ban_until