
Commit 3c4c739

replace test_too_many_messages() with a parametrized test for (almost) all rate-limit cases
1 parent 0eb7527 commit 3c4c739

1 file changed: +82 -24 lines

chia/_tests/core/server/test_rate_limits.py

Lines changed: 82 additions & 24 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from typing import Any
+
 import pytest
 from chia_rs.sized_ints import uint32

@@ -9,7 +11,7 @@
 from chia.protocols.outbound_message import make_msg
 from chia.protocols.protocol_message_types import ProtocolMessageTypes
 from chia.protocols.shared_protocol import Capability
-from chia.server.rate_limit_numbers import compose_rate_limits, get_rate_limits_to_use
+from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use
 from chia.server.rate_limit_numbers import rate_limits as rl_numbers
 from chia.server.rate_limits import RateLimiter
 from chia.server.server import ChiaServer
@@ -40,34 +42,90 @@ async def test_get_rate_limits_to_use():
     assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2)


+# we want to exercise every possibly limit we may hit
+# they are:
+# * total number of messages / 60 seconds for non-transaction messages
+# * total number of bytes / 60 seconds for non-transaction messages
+# * number of messages / 60 seconds for "transaction" messages
+# * number of bytes / 60 seconds for transaction messages
+
+
 @pytest.mark.anyio
-async def test_too_many_messages():
-    # Too many messages
+@pytest.mark.parametrize("incoming", [True, False])
+@pytest.mark.parametrize("tx_msg", [True, False])
+@pytest.mark.parametrize("cumulative_size", [True, False])
+async def test_limits_v2(incoming: bool, tx_msg: bool, cumulative_size: bool, monkeypatch):
+    # this test uses a single message type, and alters the rate limit settings
+    # for it to hit the different cases
+    print(f"{incoming=} {tx_msg=} {cumulative_size=}")
+
+    count = 1000
+    message_data = b"0" * 1024
+    msg_type = ProtocolMessageTypes.new_transaction
+
+    limits: dict[str, Any] = {}
+
+    if cumulative_size:
+        limits.update(
+            {
+                # this is the rate limit across all (non-tx) messages
+                "non_tx_freq": 2000,
+                # this is the byte size limit across all (non-tx) messages
+                "non_tx_max_total_size": 1000 * 1024,
+            }
+        )
+    else:
+        limits.update(
+            {
+                # this is the rate limit across all (non-tx) messages
+                "non_tx_freq": 1000,
+                # this is the byte size limit across all (non-tx) messages
+                "non_tx_max_total_size": 100 * 1024 * 1024,
+            }
+        )
+
+    if cumulative_size:
+        rate_limit = {msg_type: RLSettings(2000, 1024, 1000 * 1024)}
+    else:
+        rate_limit = {msg_type: RLSettings(1000, 1024, 1000 * 1024 * 1024)}
+
+    if tx_msg:
+        limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}})
+    else:
+        limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}})
+
+    def mock_get_limits(*args, **kwargs) -> dict[str, Any]:
+        return limits
+
+    import chia
+
+    monkeypatch.setattr(chia.server.rate_limits, "get_rate_limits_to_use", mock_get_limits)
+
     timer = SimClock()
-    r = RateLimiter(incoming=True, get_time=timer.monotonic)
-    new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
-    for i in range(4999):
-        assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None
+    r = RateLimiter(incoming=incoming, get_time=timer.monotonic)
+    msg = make_msg(ProtocolMessageTypes.new_transaction, message_data)

-    saw_disconnect = False
-    for i in range(4999):
-        response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
-        if response is not None:
-            saw_disconnect = True
-    assert saw_disconnect
+    for i in range(count):
+        assert r.process_msg_and_check(msg, rl_v2, rl_v2) is None

-    # Non-tx message
-    r = RateLimiter(incoming=True, get_time=timer.monotonic)
-    new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
-    for i in range(200):
-        assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None
+    expected_msg = ""

-    saw_disconnect = False
-    for i in range(200):
-        response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
-        if response is not None:
-            saw_disconnect = True
-    assert saw_disconnect
+    if cumulative_size:
+        if not tx_msg:
+            expected_msg += "non-tx size:"
+        else:
+            expected_msg += "cumulative size:"
+        expected_msg += f" {(count + 1) * len(message_data)} > {count * len(message_data) * 1.0}"
+    else:
+        if not tx_msg:
+            expected_msg += "non-tx count:"
+        else:
+            expected_msg += "message count:"
+        expected_msg += f" {count + 1} > {count * 1.0}"
+    expected_msg += " (scale factor: 1.0)"
+
+    response = r.process_msg_and_check(msg, rl_v2, rl_v2)
+    assert response == expected_msg


 @pytest.mark.anyio
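
The commit message's "(almost) all rate-limit cases" comes from stacking three boolean parametrizations: pytest expands them into 2 * 2 * 2 = 8 test instances, one per (incoming, tx_msg, cumulative_size) combination, covering the four limit categories listed in the comment block for both traffic directions. A rough standalone sketch of that expansion (not part of the commit):

# Standalone sketch: how the three stacked boolean parametrizations expand
# into 8 generated test cases.
from itertools import product

for incoming, tx_msg, cumulative_size in product([True, False], repeat=3):
    direction = "incoming" if incoming else "outgoing"
    bucket = "transaction" if tx_msg else "non-transaction"
    limit = "cumulative byte size" if cumulative_size else "message count"
    print(f"{direction}: {bucket} {limit} limit")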

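The expected rejection string asserted at the end of the test is plain arithmetic over count and len(message_data). Below is a minimal standalone sketch (not from the commit; the helper name is hypothetical) reproducing it for the tx_msg=True, cumulative_size=True case, where the third positional RLSettings argument (1000 * 1024) appears to act as the cumulative byte budget:

# Standalone sketch of the expected_msg arithmetic for tx_msg=True,
# cumulative_size=True; the format string is copied from the test above,
# the helper name is made up.
def expected_rejection(count: int = 1000, message_size: int = 1024) -> str:
    # the first `count` messages fill the 1000 * 1024 byte budget exactly;
    # message number count + 1 pushes the running total over the limit
    seen = (count + 1) * message_size    # 1025024 bytes including the extra message
    budget = count * message_size * 1.0  # 1024000.0 bytes at scale factor 1.0
    return f"cumulative size: {seen} > {budget} (scale factor: 1.0)"


assert expected_rejection() == "cumulative size: 1025024 > 1024000.0 (scale factor: 1.0)"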