
Commit bb3d479

replace test_too_many_messages() with a parametrized test for (almost) all rate-limit cases
1 parent 3adfde1 commit bb3d479

File tree

1 file changed (+82, -24 lines)

chia/_tests/core/server/test_rate_limits.py

Lines changed: 82 additions & 24 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from typing import Any
+
 import pytest
 from chia_rs.sized_ints import uint32

@@ -9,7 +11,7 @@
 from chia.protocols.outbound_message import make_msg
 from chia.protocols.protocol_message_types import ProtocolMessageTypes
 from chia.protocols.shared_protocol import Capability
-from chia.server.rate_limit_numbers import compose_rate_limits, get_rate_limits_to_use
+from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use
 from chia.server.rate_limit_numbers import rate_limits as rl_numbers
 from chia.server.rate_limits import RateLimiter
 from chia.server.server import ChiaServer
@@ -52,33 +54,89 @@ async def test_get_rate_limits_to_use():
     assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2)


+# we want to exercise every possible limit we may hit
+# they are:
+# * total number of messages / 60 seconds for non-transaction messages
+# * total number of bytes / 60 seconds for non-transaction messages
+# * number of messages / 60 seconds for "transaction" messages
+# * number of bytes / 60 seconds for transaction messages
+
+
 @pytest.mark.anyio
-async def test_too_many_messages(mock_timer):
-    # Too many messages
-    r = RateLimiter(incoming=True)
-    new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
-    for i in range(4999):
-        assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None
+@pytest.mark.parametrize("incoming", [True, False])
+@pytest.mark.parametrize("tx_msg", [True, False])
+@pytest.mark.parametrize("cumulative_size", [True, False])
+async def test_limits_v2(incoming: bool, tx_msg: bool, cumulative_size: bool, mock_timer, monkeypatch):
+    # this test uses a single message type, and alters the rate limit settings
+    # for it to hit the different cases
+    print(f"{incoming=} {tx_msg=} {cumulative_size=}")
+
+    count = 1000
+    message_data = b"0" * 1024
+    msg_type = ProtocolMessageTypes.new_transaction
+
+    limits: dict[str, Any] = {}
+
+    if cumulative_size:
+        limits.update(
+            {
+                # this is the rate limit across all (non-tx) messages
+                "non_tx_freq": 2000,
+                # this is the byte size limit across all (non-tx) messages
+                "non_tx_max_total_size": 1000 * 1024,
+            }
+        )
+    else:
+        limits.update(
+            {
+                # this is the rate limit across all (non-tx) messages
+                "non_tx_freq": 1000,
+                # this is the byte size limit across all (non-tx) messages
+                "non_tx_max_total_size": 100 * 1024 * 1024,
+            }
+        )
+
+    if cumulative_size:
+        rate_limit = {msg_type: RLSettings(2000, 1024, 1000 * 1024)}
+    else:
+        rate_limit = {msg_type: RLSettings(1000, 1024, 1000 * 1024 * 1024)}
+
+    if tx_msg:
+        limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}})
+    else:
+        limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}})
+
+    def mock_get_limits(*args, **kwargs) -> dict[str, Any]:
+        return limits

-    saw_disconnect = False
-    for i in range(4999):
-        response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
-        if response is not None:
-            saw_disconnect = True
-    assert saw_disconnect
+    import chia

-    # Non-tx message
-    r = RateLimiter(incoming=True)
-    new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
-    for i in range(200):
-        assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None
+    monkeypatch.setattr(chia.server.rate_limits, "get_rate_limits_to_use", mock_get_limits)

-    saw_disconnect = False
-    for i in range(200):
-        response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
-        if response is not None:
-            saw_disconnect = True
-    assert saw_disconnect
+    r = RateLimiter(incoming=incoming)
+    msg = make_msg(ProtocolMessageTypes.new_transaction, message_data)
+
+    for i in range(count):
+        assert r.process_msg_and_check(msg, rl_v2, rl_v2) is None
+
+    expected_msg = ""
+
+    if cumulative_size:
+        if not tx_msg:
+            expected_msg += "non-tx size:"
+        else:
+            expected_msg += "cumulative size:"
+        expected_msg += f" {(count + 1) * len(message_data)} > {count * len(message_data) * 1.0}"
+    else:
+        if not tx_msg:
+            expected_msg += "non-tx count:"
+        else:
+            expected_msg += "message count:"
+        expected_msg += f" {count + 1} > {count * 1.0}"
+    expected_msg += " (scale factor: 1.0)"
+
+    response = r.process_msg_and_check(msg, rl_v2, rl_v2)
+    assert response == expected_msg


 @pytest.mark.anyio
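
For readers skimming the diff, here is a minimal sketch of which of the four limits each (tx_msg, cumulative_size) combination is expected to trip, derived from the expected_msg branches in test_limits_v2 above. The names count and message_size are illustrative stand-ins for the test's values (1000 messages of 1024 bytes); this snippet is not part of the committed file.

# Illustrative only: mirrors the expected_msg construction in test_limits_v2.
count = 1000
message_size = 1024  # len(message_data) in the test

def expected_limit_message(tx_msg: bool, cumulative_size: bool) -> str:
    # cumulative_size exercises the byte-size limits; otherwise the count limits
    if cumulative_size:
        prefix = "cumulative size:" if tx_msg else "non-tx size:"
        body = f" {(count + 1) * message_size} > {count * message_size * 1.0}"
    else:
        prefix = "message count:" if tx_msg else "non-tx count:"
        body = f" {count + 1} > {count * 1.0}"
    return prefix + body + " (scale factor: 1.0)"

for tx_msg in (True, False):
    for cumulative_size in (True, False):
        print(f"{tx_msg=} {cumulative_size=} -> {expected_limit_message(tx_msg, cumulative_size)}")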
