diff --git a/chia/_tests/core/server/test_rate_limits.py b/chia/_tests/core/server/test_rate_limits.py index 080114c08af3..e66c9b77c59c 100644 --- a/chia/_tests/core/server/test_rate_limits.py +++ b/chia/_tests/core/server/test_rate_limits.py @@ -1,17 +1,19 @@ from __future__ import annotations -import asyncio +from dataclasses import dataclass +from typing import Any import pytest from chia_rs.sized_ints import uint32 from chia._tests.conftest import node_with_params +from chia._tests.util.misc import boolean_datacases from chia._tests.util.time_out_assert import time_out_assert from chia.protocols.full_node_protocol import RejectBlock, RejectBlocks, RespondBlock, RespondBlocks from chia.protocols.outbound_message import make_msg from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.shared_protocol import Capability -from chia.server.rate_limit_numbers import compose_rate_limits, get_rate_limits_to_use +from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use from chia.server.rate_limit_numbers import rate_limits as rl_numbers from chia.server.rate_limits import RateLimiter from chia.server.server import ChiaServer @@ -25,355 +27,442 @@ test_different_versions_results: list[int] = [] -class TestRateLimits: - @pytest.mark.anyio - async def test_get_rate_limits_to_use(self): - assert get_rate_limits_to_use(rl_v2, rl_v2) != get_rate_limits_to_use(rl_v2, rl_v1) - assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v2, rl_v1) - assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2) - - @pytest.mark.anyio - async def test_too_many_messages(self): - # Too many messages - r = RateLimiter(incoming=True) - new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40)) - for i in range(4999): - assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(4999): - response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - # Non-tx message - r = RateLimiter(incoming=True) - new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40)) - for i in range(200): - assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(200): - response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - @pytest.mark.anyio - async def test_large_message(self): - # Large tx - small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024)) - large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024)) - - r = RateLimiter(incoming=True) - assert r.process_msg_and_check(small_tx_message, rl_v2, rl_v2) is None - assert r.process_msg_and_check(large_tx_message, rl_v2, rl_v2) is not None - - small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024)) - large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024)) - large_blocks_message = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 51 * 1024 * 1024)) - r = RateLimiter(incoming=True) - assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2) is None - assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2) is None - assert r.process_msg_and_check(large_vdf_message, rl_v2, rl_v2) is not None - # this limit applies even though this message type is unlimited - assert r.process_msg_and_check(large_blocks_message, rl_v2, rl_v2) is not None - - @pytest.mark.anyio - async def test_too_much_data(self): - # Too much data - r = RateLimiter(incoming=True) - tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024)) - for i in range(40): - assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(300): - response = r.process_msg_and_check(tx_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - r = RateLimiter(incoming=True) - block_message = make_msg(ProtocolMessageTypes.respond_unfinished_block, bytes([1] * 1024 * 1024)) - for i in range(10): - assert r.process_msg_and_check(block_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(40): - response = r.process_msg_and_check(block_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - @pytest.mark.anyio - async def test_non_tx_aggregate_limits(self): - # Frequency limits - r = RateLimiter(incoming=True) - message_1 = make_msg(ProtocolMessageTypes.coin_state_update, bytes([1] * 32)) - message_2 = make_msg(ProtocolMessageTypes.request_blocks, bytes([1] * 64)) - message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 64)) - - for i in range(500): - assert r.process_msg_and_check(message_1, rl_v2, rl_v2) is None - - for i in range(500): - assert r.process_msg_and_check(message_2, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(500): - response = r.process_msg_and_check(message_3, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - # Size limits - r = RateLimiter(incoming=True) - message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 49 * 1024 * 1024)) - message_5 = make_msg(ProtocolMessageTypes.request_blocks, bytes([1] * 49 * 1024 * 1024)) - - for i in range(2): - assert r.process_msg_and_check(message_4, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(2): - response = r.process_msg_and_check(message_5, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - @pytest.mark.anyio - async def test_periodic_reset(self): - r = RateLimiter(True, 5) - tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024)) - for i in range(10): - assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(300): - response = r.process_msg_and_check(tx_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is not None - await asyncio.sleep(6) +@dataclass +class SimClock: + current_time: float = 1000.0 + + def monotonic(self) -> float: + return self.current_time + + def advance(self, duration: float) -> None: + self.current_time += duration + + +@pytest.mark.anyio +async def test_get_rate_limits_to_use(): + assert get_rate_limits_to_use(rl_v2, rl_v2) != get_rate_limits_to_use(rl_v2, rl_v1) + assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v2, rl_v1) + assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2) + + +# we want to exercise every possibly limit we may hit +# they are: +# * total number of messages / 60 seconds for non-transaction messages +# * total number of bytes / 60 seconds for non-transaction messages +# * number of messages / 60 seconds for "transaction" messages +# * number of bytes / 60 seconds for transaction messages + + +@pytest.mark.anyio +@boolean_datacases(name="incoming", true="incoming", false="outgoing") +@boolean_datacases(name="tx_msg", true="tx", false="non-tx") +@boolean_datacases(name="limit_size", true="size-limit", false="count-limit") +async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeypatch: pytest.MonkeyPatch): + # this test uses a single message type, and alters the rate limit settings + # for it to hit the different cases + + count = 1000 + message_data = b"\0" * 1024 + msg_type = ProtocolMessageTypes.new_transaction + + limits: dict[str, Any] = {} + + if limit_size: + limits.update( + { + # this is the rate limit across all (non-tx) messages + "non_tx_freq": count * 2, + # this is the byte size limit across all (non-tx) messages + "non_tx_max_total_size": count * len(message_data), + } + ) + else: + limits.update( + { + # this is the rate limit across all (non-tx) messages + "non_tx_freq": count, + # this is the byte size limit across all (non-tx) messages + "non_tx_max_total_size": count * 2 * len(message_data), + } + ) + + if limit_size: + rate_limit = {msg_type: RLSettings(count * 2, 1024, count * len(message_data))} + else: + rate_limit = {msg_type: RLSettings(count, 1024, count * 2 * len(message_data))} + + if tx_msg: + limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}}) + else: + limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}}) + + def mock_get_limits(*args, **kwargs) -> dict[str, Any]: + return limits + + import chia.server.rate_limits + + monkeypatch.setattr(chia.server.rate_limits, "get_rate_limits_to_use", mock_get_limits) + + r = RateLimiter(incoming=incoming, get_time=lambda: 0) + msg = make_msg(msg_type, message_data) + + for i in range(count): + assert r.process_msg_and_check(msg, rl_v2, rl_v2) is None + + expected_msg = "" + + if limit_size: + if not tx_msg: + expected_msg += "non-tx size:" + else: + expected_msg += "cumulative size:" + expected_msg += f" {(count + 1) * len(message_data)} > {count * len(message_data) * 1.0}" + else: + if not tx_msg: + expected_msg += "non-tx count:" + else: + expected_msg += "message count:" + expected_msg += f" {count + 1} > {count * 1.0}" + expected_msg += " (scale factor: 1.0)" + + response = r.process_msg_and_check(msg, rl_v2, rl_v2) + assert response == expected_msg + + for _ in range(10): + response = r.process_msg_and_check(msg, rl_v2, rl_v2) + # we can't stop incoming messages from arriving, counters keep + # increasing for incoming messages. For outgoing messages, we expect + # them not to be sent when hitting the rate limit, so those counters in + # the returned message stay the same + if incoming: + assert response is not None + else: + assert response == expected_msg + + +@pytest.mark.anyio +async def test_large_message(): + # Large tx + small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024)) + large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024)) + + r = RateLimiter(incoming=True, get_time=lambda: 0) + assert r.process_msg_and_check(small_tx_message, rl_v2, rl_v2) is None + assert r.process_msg_and_check(large_tx_message, rl_v2, rl_v2) is not None + + small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024)) + large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024)) + large_blocks_message = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 51 * 1024 * 1024)) + r = RateLimiter(incoming=True, get_time=lambda: 0) + assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2) is None + assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2) is None + assert r.process_msg_and_check(large_vdf_message, rl_v2, rl_v2) is not None + # this limit applies even though this message type is unlimited + assert r.process_msg_and_check(large_blocks_message, rl_v2, rl_v2) is not None + + +@pytest.mark.anyio +async def test_too_much_data(): + # Too much data + r = RateLimiter(incoming=True, get_time=lambda: 0) + tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024)) + for i in range(40): + assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(300): + response = r.process_msg_and_check(tx_message, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + r = RateLimiter(incoming=True, get_time=lambda: 0) + block_message = make_msg(ProtocolMessageTypes.respond_unfinished_block, bytes([1] * 1024 * 1024)) + for i in range(10): + assert r.process_msg_and_check(block_message, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(40): + response = r.process_msg_and_check(block_message, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + +@pytest.mark.anyio +async def test_non_tx_aggregate_limits(): + # Frequency limits + r = RateLimiter(incoming=True, get_time=lambda: 0) + message_1 = make_msg(ProtocolMessageTypes.coin_state_update, bytes([1] * 32)) + message_2 = make_msg(ProtocolMessageTypes.request_blocks, bytes([1] * 64)) + message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 64)) + + for i in range(500): + assert r.process_msg_and_check(message_1, rl_v2, rl_v2) is None + + for i in range(500): + assert r.process_msg_and_check(message_2, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(500): + response = r.process_msg_and_check(message_3, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + # Size limits + r = RateLimiter(incoming=True, get_time=lambda: 0) + message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 49 * 1024 * 1024)) + message_5 = make_msg(ProtocolMessageTypes.request_blocks, bytes([1] * 49 * 1024 * 1024)) + + for i in range(2): + assert r.process_msg_and_check(message_4, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(2): + response = r.process_msg_and_check(message_5, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + +@pytest.mark.anyio +async def test_periodic_reset(): + timer = SimClock() + r = RateLimiter(True, 5, get_time=timer.monotonic) + tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024)) + for i in range(10): assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None - # Counts reset also - r = RateLimiter(True, 5) - new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40)) - for i in range(4999): - assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(4999): - response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - await asyncio.sleep(6) + saw_disconnect = False + for i in range(300): + response = r.process_msg_and_check(tx_message, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is not None + timer.advance(6) + assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None + + # Counts reset also + r = RateLimiter(True, 5, get_time=timer.monotonic) + new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40)) + for i in range(4999): assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None - @pytest.mark.anyio - async def test_percentage_limits(self): - r = RateLimiter(True, 60, 40) - new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40)) - for i in range(50): - assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(50): - response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - r = RateLimiter(True, 60, 40) - block_message = make_msg(ProtocolMessageTypes.respond_unfinished_block, bytes([1] * 1024 * 1024)) - for i in range(5): - assert r.process_msg_and_check(block_message, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(5): - response = r.process_msg_and_check(block_message, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - # Aggregate percentage limit count - r = RateLimiter(True, 60, 40) - message_1 = make_msg(ProtocolMessageTypes.coin_state_update, bytes([1] * 5)) - message_2 = make_msg(ProtocolMessageTypes.request_blocks, bytes([1] * 32)) - message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 32)) - - for i in range(180): - assert r.process_msg_and_check(message_1, rl_v2, rl_v2) is None - for i in range(180): - assert r.process_msg_and_check(message_2, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(100): - response = r.process_msg_and_check(message_3, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - # Aggregate percentage limit max total size - r = RateLimiter(True, 60, 40) - message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 18 * 1024 * 1024)) - message_5 = make_msg(ProtocolMessageTypes.respond_unfinished_block, bytes([1] * 24 * 1024 * 1024)) - - for i in range(2): - assert r.process_msg_and_check(message_4, rl_v2, rl_v2) is None - - saw_disconnect = False - for i in range(2): - response = r.process_msg_and_check(message_5, rl_v2, rl_v2) - if response is not None: - saw_disconnect = True - assert saw_disconnect - - @pytest.mark.anyio - async def test_too_many_outgoing_messages(self): - # Too many messages - r = RateLimiter(incoming=False) - new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1])) - non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"] - - passed = 0 - blocked = 0 - for i in range(non_tx_freq): - if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None: - passed += 1 - else: - blocked += 1 - - assert passed == 10 - assert blocked == non_tx_freq - passed - - # ensure that *another* message type is not blocked because of this - - new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1])) - assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is None - - @pytest.mark.anyio - async def test_too_many_incoming_messages(self): - # Too many messages - r = RateLimiter(incoming=True) - new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1])) - non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"] - - passed = 0 - blocked = 0 - for i in range(non_tx_freq): - if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None: - passed += 1 - else: - blocked += 1 - - assert passed == 10 - assert blocked == non_tx_freq - passed - - # ensure that other message types *are* blocked because of this - - new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1])) - assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is not None - - @pytest.mark.parametrize( - "node_with_params", - [ - pytest.param( - dict( - disable_capabilities=[Capability.BLOCK_HEADERS, Capability.RATE_LIMITS_V2], - ), - id="V1", + saw_disconnect = False + for i in range(4999): + response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + timer.advance(6) + assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None + + +@pytest.mark.anyio +async def test_percentage_limits(): + r = RateLimiter(True, 60, 40, get_time=lambda: 0) + new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40)) + for i in range(50): + assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(50): + response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + r = RateLimiter(True, 60, 40, get_time=lambda: 0) + block_message = make_msg(ProtocolMessageTypes.respond_unfinished_block, bytes([1] * 1024 * 1024)) + for i in range(5): + assert r.process_msg_and_check(block_message, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(5): + response = r.process_msg_and_check(block_message, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + # Aggregate percentage limit count + r = RateLimiter(True, 60, 40, get_time=lambda: 0) + message_1 = make_msg(ProtocolMessageTypes.coin_state_update, bytes([1] * 5)) + message_2 = make_msg(ProtocolMessageTypes.request_blocks, bytes([1] * 32)) + message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 32)) + + for i in range(180): + assert r.process_msg_and_check(message_1, rl_v2, rl_v2) is None + for i in range(180): + assert r.process_msg_and_check(message_2, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(100): + response = r.process_msg_and_check(message_3, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + # Aggregate percentage limit max total size + r = RateLimiter(True, 60, 40, get_time=lambda: 0) + message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 18 * 1024 * 1024)) + message_5 = make_msg(ProtocolMessageTypes.respond_unfinished_block, bytes([1] * 24 * 1024 * 1024)) + + for i in range(2): + assert r.process_msg_and_check(message_4, rl_v2, rl_v2) is None + + saw_disconnect = False + for i in range(2): + response = r.process_msg_and_check(message_5, rl_v2, rl_v2) + if response is not None: + saw_disconnect = True + assert saw_disconnect + + +@pytest.mark.anyio +async def test_too_many_outgoing_messages(): + # Too many messages + r = RateLimiter(incoming=False, get_time=lambda: 0) + new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1])) + non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"] + + passed = 0 + blocked = 0 + for i in range(non_tx_freq): + if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None: + passed += 1 + else: + blocked += 1 + + assert passed == 10 + assert blocked == non_tx_freq - passed + + # ensure that *another* message type is not blocked because of this + + new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1])) + assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is None + + +@pytest.mark.anyio +async def test_too_many_incoming_messages(): + # Too many messages + r = RateLimiter(incoming=True, get_time=lambda: 0) + new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1])) + non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"] + + passed = 0 + blocked = 0 + for i in range(non_tx_freq): + if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None: + passed += 1 + else: + blocked += 1 + + assert passed == 10 + assert blocked == non_tx_freq - passed + + # ensure that other message types *are* blocked because of this + + new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1])) + assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is not None + + +@pytest.mark.parametrize( + "node_with_params", + [ + pytest.param( + dict( + disable_capabilities=[Capability.BLOCK_HEADERS, Capability.RATE_LIMITS_V2], ), - pytest.param( - dict( - disable_capabilities=[], - ), - id="V2", + id="V1", + ), + pytest.param( + dict( + disable_capabilities=[], ), - ], - indirect=True, - ) - @pytest.mark.parametrize( - "node_with_params_b", - [ - pytest.param( - dict( - disable_capabilities=[Capability.BLOCK_HEADERS, Capability.RATE_LIMITS_V2], - ), - id="V1", + id="V2", + ), + ], + indirect=True, +) +@pytest.mark.parametrize( + "node_with_params_b", + [ + pytest.param( + dict( + disable_capabilities=[Capability.BLOCK_HEADERS, Capability.RATE_LIMITS_V2], ), - pytest.param( - dict( - disable_capabilities=[], - ), - id="V2", + id="V1", + ), + pytest.param( + dict( + disable_capabilities=[], ), - ], - indirect=True, - ) - @pytest.mark.anyio - @pytest.mark.limit_consensus_modes(reason="save time") - async def test_different_versions(self, node_with_params, node_with_params_b, self_hostname): - node_a = node_with_params - node_b = node_with_params_b + id="V2", + ), + ], + indirect=True, +) +@pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="save time") +async def test_different_versions(node_with_params, node_with_params_b, self_hostname): + node_a = node_with_params + node_b = node_with_params_b - full_node_server_a: ChiaServer = node_a.full_node.server - full_node_server_b: ChiaServer = node_b.full_node.server + full_node_server_a: ChiaServer = node_a.full_node.server + full_node_server_b: ChiaServer = node_b.full_node.server - await full_node_server_b.start_client(PeerInfo(self_hostname, full_node_server_a.get_port()), None) + await full_node_server_b.start_client(PeerInfo(self_hostname, full_node_server_a.get_port()), None) - assert len(full_node_server_b.get_connections()) == 1 - assert len(full_node_server_a.get_connections()) == 1 + assert len(full_node_server_b.get_connections()) == 1 + assert len(full_node_server_a.get_connections()) == 1 - a_con: WSChiaConnection = full_node_server_a.get_connections()[0] - b_con: WSChiaConnection = full_node_server_b.get_connections()[0] + a_con: WSChiaConnection = full_node_server_a.get_connections()[0] + b_con: WSChiaConnection = full_node_server_b.get_connections()[0] - print(a_con.local_capabilities, a_con.peer_capabilities) - print(b_con.local_capabilities, b_con.peer_capabilities) + print(a_con.local_capabilities, a_con.peer_capabilities) + print(b_con.local_capabilities, b_con.peer_capabilities) - # The two nodes will use the same rate limits even if their versions are different - assert get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities) == get_rate_limits_to_use( - b_con.local_capabilities, b_con.peer_capabilities - ) + # The two nodes will use the same rate limits even if their versions are different + assert get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities) == get_rate_limits_to_use( + b_con.local_capabilities, b_con.peer_capabilities + ) - # The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx", - # which would mean the same rate limits are always used. This should not happen, since two nodes with V2 - # will use V2. - total_tx_msg_count = len( - get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities)["rate_limits_tx"] - ) + # The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx", + # which would mean the same rate limits are always used. This should not happen, since two nodes with V2 + # will use V2. + total_tx_msg_count = len( + get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities)["rate_limits_tx"] + ) + + test_different_versions_results.append(total_tx_msg_count) + if len(test_different_versions_results) >= 4: + assert len(set(test_different_versions_results)) >= 2 - test_different_versions_results.append(total_tx_msg_count) - if len(test_different_versions_results) >= 4: - assert len(set(test_different_versions_results)) >= 2 - @pytest.mark.anyio - async def test_compose(self): - rl_1 = rl_numbers[1] - rl_2 = rl_numbers[2] - assert ProtocolMessageTypes.respond_children in rl_1["rate_limits_other"] - assert ProtocolMessageTypes.respond_children not in rl_1["rate_limits_tx"] - assert ProtocolMessageTypes.respond_children not in rl_2["rate_limits_other"] - assert ProtocolMessageTypes.respond_children in rl_2["rate_limits_tx"] +@pytest.mark.anyio +async def test_compose(): + rl_1 = rl_numbers[1] + rl_2 = rl_numbers[2] + assert ProtocolMessageTypes.respond_children in rl_1["rate_limits_other"] + assert ProtocolMessageTypes.respond_children not in rl_1["rate_limits_tx"] + assert ProtocolMessageTypes.respond_children not in rl_2["rate_limits_other"] + assert ProtocolMessageTypes.respond_children in rl_2["rate_limits_tx"] - assert ProtocolMessageTypes.request_block in rl_1["rate_limits_other"] - assert ProtocolMessageTypes.request_block not in rl_1["rate_limits_tx"] - assert ProtocolMessageTypes.request_block not in rl_2["rate_limits_other"] - assert ProtocolMessageTypes.request_block not in rl_2["rate_limits_tx"] + assert ProtocolMessageTypes.request_block in rl_1["rate_limits_other"] + assert ProtocolMessageTypes.request_block not in rl_1["rate_limits_tx"] + assert ProtocolMessageTypes.request_block not in rl_2["rate_limits_other"] + assert ProtocolMessageTypes.request_block not in rl_2["rate_limits_tx"] - comps = compose_rate_limits(rl_1, rl_2) - # v2 limits are used if present - assert ProtocolMessageTypes.respond_children not in comps["rate_limits_other"] - assert ProtocolMessageTypes.respond_children in comps["rate_limits_tx"] + comps = compose_rate_limits(rl_1, rl_2) + # v2 limits are used if present + assert ProtocolMessageTypes.respond_children not in comps["rate_limits_other"] + assert ProtocolMessageTypes.respond_children in comps["rate_limits_tx"] - # Otherwise, fall back to v1 - assert ProtocolMessageTypes.request_block in rl_1["rate_limits_other"] - assert ProtocolMessageTypes.request_block not in rl_1["rate_limits_tx"] + # Otherwise, fall back to v1 + assert ProtocolMessageTypes.request_block in rl_1["rate_limits_other"] + assert ProtocolMessageTypes.request_block not in rl_1["rate_limits_tx"] @pytest.mark.anyio @@ -387,7 +476,7 @@ async def test_compose(self): ], ) async def test_unlimited(msg_type: ProtocolMessageTypes, size: int): - r = RateLimiter(incoming=False) + r = RateLimiter(incoming=False, get_time=lambda: 0) message = make_msg(msg_type, bytes([1] * size)) diff --git a/chia/server/rate_limit_numbers.py b/chia/server/rate_limit_numbers.py index 3edccf729370..b7eea59ae8c2 100644 --- a/chia/server/rate_limit_numbers.py +++ b/chia/server/rate_limit_numbers.py @@ -24,6 +24,8 @@ class RLSettings: # this class is used to indicate that a message type is not subject to a rate # limit, but just a per-message size limit. This may be appropriate for response # messages that are implicitly limited by their corresponding request message +# Unlimited message types are also not subject to the overall limit across all +# messages (just like messages in the "tx" category) @dataclasses.dataclass(frozen=True) class Unlimited: max_size: int # Max size of each request diff --git a/chia/server/rate_limits.py b/chia/server/rate_limits.py index e8670288dac6..685ed142517b 100644 --- a/chia/server/rate_limits.py +++ b/chia/server/rate_limits.py @@ -4,7 +4,7 @@ import logging import time from collections import Counter -from typing import Optional +from typing import Callable, Optional from chia.protocols.outbound_message import Message from chia.protocols.protocol_message_types import ProtocolMessageTypes @@ -18,14 +18,22 @@ class RateLimiter: incoming: bool reset_seconds: int - current_minute: int + current_slot: int message_counts: Counter[ProtocolMessageTypes] message_cumulative_sizes: Counter[ProtocolMessageTypes] percentage_of_limit: int non_tx_message_counts: int = 0 non_tx_cumulative_size: int = 0 + get_time: Callable[[], float] - def __init__(self, incoming: bool, reset_seconds: int = 60, percentage_of_limit: int = 100): + def __init__( + self, + incoming: bool, + reset_seconds: int = 60, + percentage_of_limit: int = 100, + *, + get_time: Callable[[], float] = time.monotonic, + ): """ The incoming parameter affects whether counters are incremented unconditionally or not. For incoming messages, the counters are always @@ -33,9 +41,10 @@ def __init__(self, incoming: bool, reset_seconds: int = 60, percentage_of_limit: if they are allowed to be sent by the rate limiter, since we won't send the messages otherwise. """ + self.get_time = get_time self.incoming = incoming self.reset_seconds = reset_seconds - self.current_minute = int(time.time() // reset_seconds) + self.current_slot = int(get_time() // reset_seconds) self.message_counts = Counter() self.message_cumulative_sizes = Counter() self.percentage_of_limit = percentage_of_limit @@ -51,9 +60,9 @@ def process_msg_and_check( hit and the message is good to be sent or received. """ - current_minute = int(time.time() // self.reset_seconds) - if current_minute != self.current_minute: - self.current_minute = current_minute + current_slot = int(self.get_time() // self.reset_seconds) + if current_slot != self.current_slot: + self.current_slot = current_slot self.message_counts = Counter() self.message_cumulative_sizes = Counter() self.non_tx_message_counts = 0 @@ -74,7 +83,7 @@ def process_msg_and_check( rate_limits = get_rate_limits_to_use(our_capabilities, peer_capabilities) try: - limits: RLSettings = rate_limits["default_settings"] + limits: RLSettings if message_type in rate_limits["rate_limits_tx"]: limits = rate_limits["rate_limits_tx"][message_type] elif message_type in rate_limits["rate_limits_other"]: @@ -104,6 +113,7 @@ def process_msg_and_check( log.warning( f"Message type {message_type} not found in rate limits (scale factor: {proportion_of_limit})", ) + limits = rate_limits["default_settings"] if isinstance(limits, Unlimited): # this message type is not rate limited. This is used for @@ -121,9 +131,9 @@ def process_msg_and_check( if new_message_counts > limits.frequency * proportion_of_limit: return " ".join( [ - f"message count: {new_message_counts}" - f"> {limits.frequency * proportion_of_limit}" - f"(scale factor: {proportion_of_limit})" + f"message count: {new_message_counts}", + f"> {limits.frequency * proportion_of_limit}", + f"(scale factor: {proportion_of_limit})", ] ) if len(message.data) > limits.max_size: