Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions chia/_tests/core/full_node/test_full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3362,39 +3362,58 @@ async def test_pending_tx_cache_retry_on_new_peak(
@pytest.mark.anyio
@pytest.mark.parametrize("mismatch_cost", [True, False])
@pytest.mark.parametrize("mismatch_fee", [True, False])
@pytest.mark.parametrize("tx_already_seen", [True, False])
@pytest.mark.parametrize("mismatch_on_reannounce", [True, False])
async def test_ban_for_mismatched_tx_cost_fee(
setup_two_nodes_fixture: tuple[list[FullNodeSimulator], list[tuple[WalletNode, ChiaServer]], BlockTools],
three_nodes: list[FullNodeAPI],
bt: BlockTools,
self_hostname: str,
mismatch_cost: bool,
mismatch_fee: bool,
tx_already_seen: bool,
mismatch_on_reannounce: bool,
) -> None:
"""
Tests that a peer gets banned if it sends a `NewTransaction` message with a
cost and/or fee that doesn't match the transaction's validation cost/fee.
We setup two full nodes with the test transaction as already seen, and we
check its validation cost and fee against the ones specified in the
`NewTransaction` message.
We setup full nodes, and with `tx_already_seen` we control whether the
first full node has this transaction already or it needs to request it.
In both cases we check the transaction's validation cost and fee against
the ones specified in the `NewTransaction` message.
With `mismatch_on_reannounce` we control whether the peer sent us the same
transaction twice with different cost and fee.
"""
nodes, _, bt = setup_two_nodes_fixture
full_node_1, full_node_2 = nodes
full_node_1, full_node_2, full_node_3 = three_nodes
server_1 = full_node_1.full_node.server
server_2 = full_node_2.full_node.server
server_3 = full_node_3.full_node.server
await server_2.start_client(PeerInfo(self_hostname, server_1.get_port()), full_node_2.full_node.on_connect)
await server_3.start_client(PeerInfo(self_hostname, server_1.get_port()), full_node_3.full_node.on_connect)
ws_con_1 = next(iter(server_1.all_connections.values()))
ws_con_2 = next(iter(server_2.all_connections.values()))
ws_con_3 = next(iter(server_3.all_connections.values()))
wallet = WalletTool(test_constants)
wallet_ph = wallet.get_new_puzzlehash()
# If we're covering that the first full node has this transaction already
# we must add it accordingly, otherwise we'll add it to the second node so
# that the first node requests it, reacting to the NewTransaction message.
if tx_already_seen:
node = full_node_1.full_node
ws_con = ws_con_1
else:
node = full_node_2.full_node
ws_con = ws_con_2
blocks = bt.get_consecutive_blocks(
3, guarantee_transaction_block=True, farmer_reward_puzzle_hash=wallet_ph, pool_reward_puzzle_hash=wallet_ph
)
for block in blocks:
await full_node_1.full_node.add_block(block)
await node.add_block(block)
# Create a transaction and add it to the relevant full node's mempool
coin = blocks[-1].get_included_reward_coins()[0]
sb = wallet.generate_signed_transaction(uint64(42), wallet_ph, coin)
sb_name = sb.name()
await full_node_1.full_node.add_transaction(sb, sb_name, ws_con_1)
mempool_item = full_node_1.full_node.mempool_manager.get_mempool_item(sb_name)
await node.add_transaction(sb, sb_name, ws_con)
mempool_item = node.mempool_manager.get_mempool_item(sb_name)
assert mempool_item is not None
# Now send a NewTransaction with a cost and/or fee mismatch from the second
# full node.
Expand All @@ -3405,8 +3424,35 @@ async def test_ban_for_mismatched_tx_cost_fee(
# second node.
full_node_2_ip = "1.3.3.7"
ws_con_1.peer_info = PeerInfo(full_node_2_ip, ws_con_1.peer_info.port)

# Send the NewTransaction message from the second node to the first
await ws_con_2.send_message(msg)
async def send_from_node_2() -> None:
await ws_con_2.send_message(msg)

# Send this message from the third node as well, just to end up with two
# peers advertising the same transaction at the same time.
async def send_from_node_3() -> None:
await ws_con_3.send_message(msg)

await asyncio.gather(send_from_node_2(), send_from_node_3())
if mismatch_on_reannounce and (mismatch_cost or mismatch_fee):
# Send a second NewTransaction that doesn't match the first
reannounce_cost = uint64(cost + 1) if mismatch_cost else cost
reannounce_fee = uint64(fee + 1) if mismatch_fee else fee
reannounce_msg = make_msg(
ProtocolMessageTypes.new_transaction, NewTransaction(mempool_item.name, reannounce_cost, reannounce_fee)
)
await ws_con_2.send_message(reannounce_msg)
# Make sure the peer is banned as it sent the same transaction twice
# with different cost and/or fee.
await time_out_assert(5, lambda: full_node_2_ip in server_1.banned_peers)
return
if not tx_already_seen:
# When the first full node receives the NewTransaction message and it
# hasn't seen the transaction before, it will issue a transaction
# request. We need to wait until it receives the transaction and add it
# to its mempool.
await time_out_assert(30, lambda: full_node_1.full_node.mempool_manager.seen(mempool_item.name))
# Make sure the first full node has banned the second as the item it has
# already seen has a different validation cost and/or fee than the one from
# the NewTransaction message.
Expand Down
34 changes: 29 additions & 5 deletions chia/full_node/full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from chia.full_node.mempool_manager import MempoolManager
from chia.full_node.subscriptions import PeerSubscriptions, peers_for_spend_bundle
from chia.full_node.sync_store import Peak, SyncStore
from chia.full_node.tx_processing_queue import TransactionQueue, TransactionQueueEntry
from chia.full_node.tx_processing_queue import PeerWithTx, TransactionQueue, TransactionQueueEntry
from chia.full_node.weight_proof import WeightProofHandler
from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol
from chia.protocols.farmer_protocol import SignagePointSourceData, SPSubSlotSourceData, SPVDFSourceData
Expand Down Expand Up @@ -498,7 +498,9 @@ def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
async def _handle_one_transaction(self, entry: TransactionQueueEntry) -> None:
peer = entry.peer
try:
inc_status, err = await self.add_transaction(entry.transaction, entry.spend_name, peer, entry.test)
inc_status, err = await self.add_transaction(
entry.transaction, entry.spend_name, peer, entry.test, entry.peers_with_tx
)
entry.done.set((inc_status, err))
except asyncio.CancelledError:
error_stack = traceback.format_exc()
Expand Down Expand Up @@ -2761,7 +2763,14 @@ async def add_end_of_sub_slot(
return None, False

async def add_transaction(
self, transaction: SpendBundle, spend_name: bytes32, peer: Optional[WSChiaConnection] = None, test: bool = False
self,
transaction: SpendBundle,
spend_name: bytes32,
peer: Optional[WSChiaConnection] = None,
test: bool = False,
# Map of peer ID to its hostname, the fee and the cost it advertised
# for this transaction.
peers_with_tx: dict[bytes32, PeerWithTx] = {},
) -> tuple[MempoolInclusionStatus, Optional[Err]]:
if self.sync_store.get_sync_mode():
return MempoolInclusionStatus.FAILED, Err.NO_TRANSACTIONS_WHILE_SYNCING
Expand Down Expand Up @@ -2810,10 +2819,25 @@ async def add_transaction(
f"{self.mempool_manager.mempool.total_mempool_cost() / 5000000}"
)

# Only broadcast successful transactions, not pending ones. Otherwise it's a DOS
# vector.
mempool_item = self.mempool_manager.get_mempool_item(spend_name)
assert mempool_item is not None
# Now that we validated this transaction, check what fees and
# costs the peers have advertised for it.
for peer_id, entry in peers_with_tx.items():
if entry.advertised_fee == mempool_item.fee and entry.advertised_cost == mempool_item.cost:
continue
self.log.warning(
f"Banning peer {peer_id}. Sent us a new tx {spend_name} with mismatch "
f"on cost {entry.advertised_cost} vs validation cost {mempool_item.cost} and/or "
f"fee {entry.advertised_fee} vs {mempool_item.fee}."
)
peer = self.server.all_connections.get(peer_id)
if peer is None:
self.server.ban_peer(entry.peer_host, CONSENSUS_ERROR_BAN_SECONDS)
else:
await peer.close(CONSENSUS_ERROR_BAN_SECONDS)
# Only broadcast successful transactions, not pending ones. Otherwise it's a DOS
# vector.
await self.broadcast_removed_tx(info.removals)
await self.broadcast_added_tx(mempool_item, current_peer=peer)

Expand Down
60 changes: 41 additions & 19 deletions chia/full_node/full_node_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import copy
import logging
import time
import traceback
Expand Down Expand Up @@ -44,13 +45,13 @@
from chia.full_node.coin_store import CoinStore
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.full_block_utils import get_height_and_tx_status_from_block, header_block_from_block
from chia.full_node.tx_processing_queue import TransactionQueueEntry, TransactionQueueFull
from chia.full_node.tx_processing_queue import PeerWithTx, TransactionQueueEntry, TransactionQueueFull
from chia.protocols import farmer_protocol, full_node_protocol, introducer_protocol, timelord_protocol, wallet_protocol
from chia.protocols.fee_estimate import FeeEstimate, FeeEstimateGroup, fee_rate_v2_to_v1
from chia.protocols.full_node_protocol import RejectBlock, RejectBlocks
from chia.protocols.outbound_message import Message, make_msg
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_timing import RATE_LIMITER_BAN_SECONDS
from chia.protocols.protocol_timing import CONSENSUS_ERROR_BAN_SECONDS, RATE_LIMITER_BAN_SECONDS
from chia.protocols.shared_protocol import Capability
from chia.protocols.wallet_protocol import (
PuzzleSolutionResponse,
Expand Down Expand Up @@ -89,6 +90,10 @@ async def tx_request_and_timeout(full_node: FullNode, transaction_id: bytes32, t
receive it or timeout.
"""
counter = 0
# Make a copy as we'll pop from it here. We keep the original intact as we
# need it in `respond_transaction` when constructing `TransactionQueueEntry`
# to put into `transaction_queue`.
peers_with_tx = copy.copy(full_node.full_node_store.peers_with_tx.get(transaction_id, {}))
try:
while True:
# Limit to asking a few peers, it's possible that this tx got included on chain already
Expand All @@ -98,10 +103,9 @@ async def tx_request_and_timeout(full_node: FullNode, transaction_id: bytes32, t
break
if transaction_id not in full_node.full_node_store.peers_with_tx:
break
peers_with_tx: set[bytes32] = full_node.full_node_store.peers_with_tx[transaction_id]
if len(peers_with_tx) == 0:
break
peer_id = peers_with_tx.pop()
peer_id, _ = peers_with_tx.popitem()
assert full_node.server is not None
if peer_id not in full_node.server.all_connections:
continue
Expand Down Expand Up @@ -221,28 +225,44 @@ async def new_transaction(
f"with mismatch on cost {transaction.cost} vs validation cost {mempool_item.cost} and/or "
f"fee {transaction.fees} vs {mempool_item.fee}."
)
await peer.close(RATE_LIMITER_BAN_SECONDS)
await peer.close(CONSENSUS_ERROR_BAN_SECONDS)
return None

if self.full_node.mempool_manager.is_fee_enough(transaction.fees, transaction.cost):
# If there's current pending request just add this peer to the set of peers that have this tx
if transaction.transaction_id in self.full_node.full_node_store.pending_tx_request:
if transaction.transaction_id in self.full_node.full_node_store.peers_with_tx:
current_set = self.full_node.full_node_store.peers_with_tx[transaction.transaction_id]
if peer.peer_node_id in current_set:
return None
current_set.add(peer.peer_node_id)
current_map = self.full_node.full_node_store.peers_with_tx.get(transaction.transaction_id)
if current_map is None:
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = {
peer.peer_node_id: PeerWithTx(
peer_host=peer.peer_info.host,
advertised_fee=transaction.fees,
advertised_cost=transaction.cost,
)
}
return None
else:
new_set = set()
new_set.add(peer.peer_node_id)
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = new_set
prev = current_map.get(peer.peer_node_id)
if prev is not None:
if prev.advertised_fee != transaction.fees or prev.advertised_cost != transaction.cost:
self.log.warning(
f"Banning peer {peer.peer_node_id}. Sent us a new tx {transaction.transaction_id} with "
f"mismatch on cost {transaction.cost} vs previous advertised cost {prev.advertised_cost} "
f"and/or fee {transaction.fees} vs previous advertised fee {prev.advertised_fee}."
)
await peer.close(CONSENSUS_ERROR_BAN_SECONDS)
return None
current_map[peer.peer_node_id] = PeerWithTx(
peer_host=peer.peer_info.host, advertised_fee=transaction.fees, advertised_cost=transaction.cost
)
return None

self.full_node.full_node_store.pending_tx_request[transaction.transaction_id] = peer.peer_node_id
new_set = set()
new_set.add(peer.peer_node_id)
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = new_set
self.full_node.full_node_store.peers_with_tx[transaction.transaction_id] = {
peer.peer_node_id: PeerWithTx(
peer_host=peer.peer_info.host, advertised_fee=transaction.fees, advertised_cost=transaction.cost
)
}

task_id: bytes32 = bytes32.secret()
fetch_task = create_referenced_task(
tx_request_and_timeout(self.full_node, transaction.transaction_id, task_id)
Expand Down Expand Up @@ -282,13 +302,15 @@ async def respond_transaction(
spend_name = std_hash(tx_bytes)
if spend_name in self.full_node.full_node_store.pending_tx_request:
self.full_node.full_node_store.pending_tx_request.pop(spend_name)
peers_with_tx = {}
if spend_name in self.full_node.full_node_store.peers_with_tx:
self.full_node.full_node_store.peers_with_tx.pop(spend_name)
peers_with_tx = self.full_node.full_node_store.peers_with_tx.pop(spend_name)

# TODO: Use fee in priority calculation, to prioritize high fee TXs
try:
await self.full_node.transaction_queue.put(
TransactionQueueEntry(tx.transaction, tx_bytes, spend_name, peer, test), peer.peer_node_id
TransactionQueueEntry(tx.transaction, tx_bytes, spend_name, peer, test, peers_with_tx),
peer.peer_node_id,
)
except TransactionQueueFull:
pass # we can't do anything here, the tx will be dropped. We might do something in the future.
Expand Down
5 changes: 4 additions & 1 deletion chia/full_node/full_node_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from chia.consensus.multiprocess_validation import PreValidationResult
from chia.consensus.pot_iterations import calculate_sp_interval_iters
from chia.consensus.signage_point import SignagePoint
from chia.full_node.tx_processing_queue import PeerWithTx
from chia.protocols import timelord_protocol
from chia.protocols.outbound_message import Message
from chia.types.blockchain_format.classgroup import ClassgroupElement
Expand Down Expand Up @@ -132,7 +133,9 @@ class FullNodeStore:
recent_eos: LRUCache[bytes32, tuple[EndOfSubSlotBundle, float]]

pending_tx_request: dict[bytes32, bytes32] # tx_id: peer_id
peers_with_tx: dict[bytes32, set[bytes32]] # tx_id: set[peer_ids}
# Map of transaction ID to the map of peer ID to its hostname, fee and cost
# it advertised for that transaction.
peers_with_tx: dict[bytes32, dict[bytes32, PeerWithTx]]
tx_fetch_tasks: dict[bytes32, asyncio.Task[None]] # Task id: task
serialized_wp_message: Optional[Message]
serialized_wp_message_tip: Optional[bytes32]
Expand Down
11 changes: 11 additions & 0 deletions chia/full_node/tx_processing_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from chia_rs import SpendBundle
from chia_rs.sized_bytes import bytes32
from chia_rs.sized_ints import uint64

from chia.server.ws_connection import WSChiaConnection
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
Expand Down Expand Up @@ -45,6 +46,13 @@ async def wait(self) -> T:
return self._value


@dataclasses.dataclass(frozen=True)
class PeerWithTx:
peer_host: str
advertised_fee: uint64
advertised_cost: uint64


@dataclass(frozen=True)
class TransactionQueueEntry:
"""
Expand All @@ -56,6 +64,9 @@ class TransactionQueueEntry:
spend_name: bytes32
peer: Optional[WSChiaConnection] = field(compare=False)
test: bool = field(compare=False)
# IDs of peers that advertised this transaction via new_transaction, along
# with their hostname, fee and cost.
peers_with_tx: dict[bytes32, PeerWithTx] = field(default_factory=dict, compare=False)
done: ValuedEvent[tuple[MempoolInclusionStatus, Optional[Err]]] = field(
default_factory=ValuedEvent,
compare=False,
Expand Down
17 changes: 9 additions & 8 deletions chia/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,14 +550,7 @@ async def connection_closed(
)
ban_time = 0
if ban_time > 0:
ban_until: float = time.time() + ban_time
self.log.warning(f"Banning {connection.peer_info.host} for {ban_time} seconds")
if connection.peer_info.host in self.banned_peers:
self.banned_peers[connection.peer_info.host] = max(
ban_until, self.banned_peers[connection.peer_info.host]
)
else:
self.banned_peers[connection.peer_info.host] = ban_until
self.ban_peer(connection.peer_info.host, ban_time)

present_connection = self.all_connections.get(connection.peer_node_id)
if present_connection is connection:
Expand Down Expand Up @@ -737,3 +730,11 @@ def is_trusted_peer(self, peer: WSChiaConnection, trusted_peers: dict[str, Any])

def set_capabilities(self, capabilities: list[tuple[uint16, str]]) -> None:
self._local_capabilities_for_handshake = capabilities

def ban_peer(self, host: str, ban_time: int) -> None:
ban_until: float = time.time() + ban_time
self.log.warning(f"Banning {host} for {ban_time} seconds")
if host in self.banned_peers:
self.banned_peers[host] = max(ban_until, self.banned_peers[host])
else:
self.banned_peers[host] = ban_until
Loading