|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from typing import Any |
| 4 | + |
3 | 5 | import pytest
|
4 | 6 | from chia_rs.sized_ints import uint32
|
5 | 7 |
|
|
9 | 11 | from chia.protocols.outbound_message import make_msg
|
10 | 12 | from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
11 | 13 | from chia.protocols.shared_protocol import Capability
|
12 |
| -from chia.server.rate_limit_numbers import compose_rate_limits, get_rate_limits_to_use |
| 14 | +from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use |
13 | 15 | from chia.server.rate_limit_numbers import rate_limits as rl_numbers
|
14 | 16 | from chia.server.rate_limits import RateLimiter
|
15 | 17 | from chia.server.server import ChiaServer
|
@@ -40,34 +42,90 @@ async def test_get_rate_limits_to_use():
|
40 | 42 | assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2)
|
41 | 43 |
|
42 | 44 |
|
| 45 | +# we want to exercise every possibly limit we may hit |
| 46 | +# they are: |
| 47 | +# * total number of messages / 60 seconds for non-transaction messages |
| 48 | +# * total number of bytes / 60 seconds for non-transaction messages |
| 49 | +# * number of messages / 60 seconds for "transaction" messages |
| 50 | +# * number of bytes / 60 seconds for transaction messages |
| 51 | + |
| 52 | + |
43 | 53 | @pytest.mark.anyio
|
44 |
| -async def test_too_many_messages(): |
45 |
| - # Too many messages |
| 54 | +@pytest.mark.parametrize("incoming", [True, False]) |
| 55 | +@pytest.mark.parametrize("tx_msg", [True, False]) |
| 56 | +@pytest.mark.parametrize("cumulative_size", [True, False]) |
| 57 | +async def test_limits_v2(incoming: bool, tx_msg: bool, cumulative_size: bool, monkeypatch): |
| 58 | + # this test uses a single message type, and alters the rate limit settings |
| 59 | + # for it to hit the different cases |
| 60 | + print(f"{incoming=} {tx_msg=} {cumulative_size=}") |
| 61 | + |
| 62 | + count = 1000 |
| 63 | + message_data = b"0" * 1024 |
| 64 | + msg_type = ProtocolMessageTypes.new_transaction |
| 65 | + |
| 66 | + limits: dict[str, Any] = {} |
| 67 | + |
| 68 | + if cumulative_size: |
| 69 | + limits.update( |
| 70 | + { |
| 71 | + # this is the rate limit across all (non-tx) messages |
| 72 | + "non_tx_freq": 2000, |
| 73 | + # this is the byte size limit across all (non-tx) messages |
| 74 | + "non_tx_max_total_size": 1000 * 1024, |
| 75 | + } |
| 76 | + ) |
| 77 | + else: |
| 78 | + limits.update( |
| 79 | + { |
| 80 | + # this is the rate limit across all (non-tx) messages |
| 81 | + "non_tx_freq": 1000, |
| 82 | + # this is the byte size limit across all (non-tx) messages |
| 83 | + "non_tx_max_total_size": 100 * 1024 * 1024, |
| 84 | + } |
| 85 | + ) |
| 86 | + |
| 87 | + if cumulative_size: |
| 88 | + rate_limit = {msg_type: RLSettings(2000, 1024, 1000 * 1024)} |
| 89 | + else: |
| 90 | + rate_limit = {msg_type: RLSettings(1000, 1024, 1000 * 1024 * 1024)} |
| 91 | + |
| 92 | + if tx_msg: |
| 93 | + limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}}) |
| 94 | + else: |
| 95 | + limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}}) |
| 96 | + |
| 97 | + def mock_get_limits(*args, **kwargs) -> dict[str, Any]: |
| 98 | + return limits |
| 99 | + |
| 100 | + import chia |
| 101 | + |
| 102 | + monkeypatch.setattr(chia.server.rate_limits, "get_rate_limits_to_use", mock_get_limits) |
| 103 | + |
46 | 104 | timer = SimClock()
|
47 |
| - r = RateLimiter(incoming=True, get_time=timer.monotonic) |
48 |
| - new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40)) |
49 |
| - for i in range(4999): |
50 |
| - assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None |
| 105 | + r = RateLimiter(incoming=incoming, get_time=timer.monotonic) |
| 106 | + msg = make_msg(ProtocolMessageTypes.new_transaction, message_data) |
51 | 107 |
|
52 |
| - saw_disconnect = False |
53 |
| - for i in range(4999): |
54 |
| - response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) |
55 |
| - if response is not None: |
56 |
| - saw_disconnect = True |
57 |
| - assert saw_disconnect |
| 108 | + for i in range(count): |
| 109 | + assert r.process_msg_and_check(msg, rl_v2, rl_v2) is None |
58 | 110 |
|
59 |
| - # Non-tx message |
60 |
| - r = RateLimiter(incoming=True, get_time=timer.monotonic) |
61 |
| - new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40)) |
62 |
| - for i in range(200): |
63 |
| - assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None |
| 111 | + expected_msg = "" |
64 | 112 |
|
65 |
| - saw_disconnect = False |
66 |
| - for i in range(200): |
67 |
| - response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) |
68 |
| - if response is not None: |
69 |
| - saw_disconnect = True |
70 |
| - assert saw_disconnect |
| 113 | + if cumulative_size: |
| 114 | + if not tx_msg: |
| 115 | + expected_msg += "non-tx size:" |
| 116 | + else: |
| 117 | + expected_msg += "cumulative size:" |
| 118 | + expected_msg += f" {(count + 1) * len(message_data)} > {count * len(message_data) * 1.0}" |
| 119 | + else: |
| 120 | + if not tx_msg: |
| 121 | + expected_msg += "non-tx count:" |
| 122 | + else: |
| 123 | + expected_msg += "message count:" |
| 124 | + expected_msg += f" {count + 1} > {count * 1.0}" |
| 125 | + expected_msg += " (scale factor: 1.0)" |
| 126 | + |
| 127 | + response = r.process_msg_and_check(msg, rl_v2, rl_v2) |
| 128 | + assert response == expected_msg |
71 | 129 |
|
72 | 130 |
|
73 | 131 | @pytest.mark.anyio
|
|
0 commit comments