|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from typing import Any |
| 4 | + |
3 | 5 | import pytest
|
4 | 6 | from chia_rs.sized_ints import uint32
|
5 | 7 |
|
|
9 | 11 | from chia.protocols.outbound_message import make_msg
|
10 | 12 | from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
11 | 13 | from chia.protocols.shared_protocol import Capability
|
12 |
| -from chia.server.rate_limit_numbers import compose_rate_limits, get_rate_limits_to_use |
| 14 | +from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use |
13 | 15 | from chia.server.rate_limit_numbers import rate_limits as rl_numbers
|
14 | 16 | from chia.server.rate_limits import RateLimiter
|
15 | 17 | from chia.server.server import ChiaServer
|
@@ -52,33 +54,89 @@ async def test_get_rate_limits_to_use():
|
52 | 54 | assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2)
|
53 | 55 |
|
54 | 56 |
|
| 57 | +# we want to exercise every possible limit we may hit |
| 58 | +# they are: |
| 59 | +# * total number of messages / 60 seconds for non-transaction messages |
| 60 | +# * total number of bytes / 60 seconds for non-transaction messages |
| 61 | +# * number of messages / 60 seconds for "transaction" messages |
| 62 | +# * number of bytes / 60 seconds for transaction messages |
| 63 | + |
| 64 | + |
@pytest.mark.anyio
@pytest.mark.parametrize("incoming", [True, False])
@pytest.mark.parametrize("tx_msg", [True, False])
@pytest.mark.parametrize("cumulative_size", [True, False])
async def test_limits_v2(
    incoming: bool, tx_msg: bool, cumulative_size: bool, mock_timer: Any, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Exercise every limit the rate limiter can report, one per parametrization.

    A single message type is used throughout; the rate-limit settings are
    patched so each (tx_msg, cumulative_size) combination trips exactly one of:
      * message count / 60 s for non-transaction messages
      * total byte size / 60 s for non-transaction messages
      * message count / 60 s for transaction messages
      * total byte size / 60 s for transaction messages
    """
    count = 1000
    message_data = b"0" * 1024
    msg_type = ProtocolMessageTypes.new_transaction

    limits: dict[str, Any] = {}

    # The aggregate (non-tx) limits and the per-message RLSettings are chosen
    # together: whichever dimension is under test is sized so that message
    # number `count + 1` is the first one to exceed it.
    if cumulative_size:
        limits.update(
            {
                # this is the rate limit across all (non-tx) messages
                "non_tx_freq": 2000,
                # this is the byte size limit across all (non-tx) messages
                "non_tx_max_total_size": 1000 * 1024,
            }
        )
        per_message_limit = {msg_type: RLSettings(2000, 1024, 1000 * 1024)}
    else:
        limits.update(
            {
                # this is the rate limit across all (non-tx) messages
                "non_tx_freq": 1000,
                # this is the byte size limit across all (non-tx) messages
                "non_tx_max_total_size": 100 * 1024 * 1024,
            }
        )
        per_message_limit = {msg_type: RLSettings(1000, 1024, 1000 * 1024 * 1024)}

    if tx_msg:
        limits.update({"rate_limits_tx": per_message_limit, "rate_limits_other": {}})
    else:
        limits.update({"rate_limits_other": per_message_limit, "rate_limits_tx": {}})

    def mock_get_limits(*args: Any, **kwargs: Any) -> dict[str, Any]:
        # replacement for get_rate_limits_to_use(): always hand back the
        # settings crafted above, regardless of negotiated capabilities
        return limits

    # import the submodule explicitly; a bare `import chia` only exposes
    # `chia.server.rate_limits` if something else already imported it
    import chia.server.rate_limits

    monkeypatch.setattr(chia.server.rate_limits, "get_rate_limits_to_use", mock_get_limits)

    r = RateLimiter(incoming=incoming)
    msg = make_msg(ProtocolMessageTypes.new_transaction, message_data)

    # the first `count` messages all fit inside the configured limit
    for i in range(count):
        assert r.process_msg_and_check(msg, rl_v2, rl_v2) is None

    # message count + 1 exceeds the limit; build the exact diagnostic string
    # the rate limiter is expected to return for the dimension under test
    expected_msg = ""

    if cumulative_size:
        if not tx_msg:
            expected_msg += "non-tx size:"
        else:
            expected_msg += "cumulative size:"
        expected_msg += f" {(count + 1) * len(message_data)} > {count * len(message_data) * 1.0}"
    else:
        if not tx_msg:
            expected_msg += "non-tx count:"
        else:
            expected_msg += "message count:"
        expected_msg += f" {count + 1} > {count * 1.0}"
    expected_msg += " (scale factor: 1.0)"

    response = r.process_msg_and_check(msg, rl_v2, rl_v2)
    assert response == expected_msg
82 | 140 |
|
83 | 141 |
|
84 | 142 | @pytest.mark.anyio
|
|
0 commit comments