Skip to content

Commit 6bb5a7f

Browse files
committed
simplify rate_limit_numbers by flattening the structure. just map message type -> rate limits
1 parent 220f375 commit 6bb5a7f

File tree

3 files changed

+186
-259
lines changed

3 files changed

+186
-259
lines changed

chia/_tests/core/server/test_rate_limits.py

Lines changed: 17 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import Union
44

55
import pytest
66
from chia_rs.sized_ints import uint32
@@ -11,8 +11,7 @@
1111
from chia.protocols.outbound_message import make_msg
1212
from chia.protocols.protocol_message_types import ProtocolMessageTypes
1313
from chia.protocols.shared_protocol import Capability
14-
from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use
15-
from chia.server.rate_limit_numbers import rate_limits as rl_numbers
14+
from chia.server.rate_limit_numbers import RLSettings, Unlimited, get_rate_limits_to_use
1615
from chia.server.rate_limits import RateLimiter
1716
from chia.server.server import ChiaServer
1817
from chia.server.ws_connection import WSChiaConnection
@@ -63,47 +62,27 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, cumulative_size: bool, mo
6362
message_data = b"0" * 1024
6463
msg_type = ProtocolMessageTypes.new_transaction
6564

66-
limits: dict[str, Any] = {}
67-
6865
if cumulative_size:
69-
limits.update(
70-
{
71-
# this is the rate limit across all (non-tx) messages
72-
"non_tx_freq": 2000,
73-
# this is the byte size limit across all (non-tx) messages
74-
"non_tx_max_total_size": 1000 * 1024,
75-
}
76-
)
66+
agg_limit = RLSettings(False, 2000, 1000 * 1024, 1000 * 1024)
7767
else:
78-
limits.update(
79-
{
80-
# this is the rate limit across all (non-tx) messages
81-
"non_tx_freq": 1000,
82-
# this is the byte size limit across all (non-tx) messages
83-
"non_tx_max_total_size": 100 * 1024 * 1024,
84-
}
85-
)
68+
agg_limit = RLSettings(False, 1000, 100 * 1024 * 1024, 100 * 1024 * 1024)
8669

70+
limits: dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]]
8771
if cumulative_size:
88-
rate_limit = {msg_type: RLSettings(2000, 1024, 1000 * 1024)}
89-
else:
90-
rate_limit = {msg_type: RLSettings(1000, 1024, 1000 * 1024 * 1024)}
91-
92-
if tx_msg:
93-
limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}})
72+
limits = {msg_type: RLSettings(not tx_msg, 2000, 1024, 1000 * 1024)}
9473
else:
95-
limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}})
74+
limits = {msg_type: RLSettings(not tx_msg, 1000, 1024, 1000 * 1024 * 1024)}
9675

97-
def mock_get_limits(*args, **kwargs) -> dict[str, Any]:
98-
return limits
76+
def mock_get_limits(*args, **kwargs) -> tuple[dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]], RLSettings]:
77+
return limits, agg_limit
9978

10079
import chia
10180

10281
monkeypatch.setattr(chia.server.rate_limits, "get_rate_limits_to_use", mock_get_limits)
10382

10483
timer = SimClock()
10584
r = RateLimiter(incoming=incoming, get_time=timer.monotonic)
106-
msg = make_msg(ProtocolMessageTypes.new_transaction, message_data)
85+
msg = make_msg(msg_type, message_data)
10786

10887
for i in range(count):
10988
assert r.process_msg_and_check(msg, rl_v2, rl_v2) is None
@@ -318,18 +297,18 @@ async def test_too_many_outgoing_messages():
318297
timer = SimClock()
319298
r = RateLimiter(incoming=False, get_time=timer.monotonic)
320299
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
321-
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
300+
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)
322301

323302
passed = 0
324303
blocked = 0
325-
for i in range(non_tx_freq):
304+
for i in range(agg_limit.frequency):
326305
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
327306
passed += 1
328307
else:
329308
blocked += 1
330309

331310
assert passed == 10
332-
assert blocked == non_tx_freq - passed
311+
assert blocked == agg_limit.frequency - passed
333312

334313
# ensure that *another* message type is not blocked because of this
335314

@@ -343,18 +322,18 @@ async def test_too_many_incoming_messages():
343322
timer = SimClock()
344323
r = RateLimiter(incoming=True, get_time=timer.monotonic)
345324
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
346-
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
325+
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)
347326

348327
passed = 0
349328
blocked = 0
350-
for i in range(non_tx_freq):
329+
for i in range(agg_limit.frequency):
351330
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
352331
passed += 1
353332
else:
354333
blocked += 1
355334

356335
assert passed == 10
357-
assert blocked == non_tx_freq - passed
336+
assert blocked == agg_limit.frequency - passed
358337

359338
# ensure that other message types *are* blocked because of this
360339

@@ -426,39 +405,13 @@ async def test_different_versions(node_with_params, node_with_params_b, self_hos
426405
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
427406
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
428407
# will use V2.
429-
total_tx_msg_count = len(
430-
get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities)["rate_limits_tx"]
431-
)
408+
total_tx_msg_count = len(get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities))
432409

433410
test_different_versions_results.append(total_tx_msg_count)
434411
if len(test_different_versions_results) >= 4:
435412
assert len(set(test_different_versions_results)) >= 2
436413

437414

438-
@pytest.mark.anyio
439-
async def test_compose():
440-
rl_1 = rl_numbers[1]
441-
rl_2 = rl_numbers[2]
442-
assert ProtocolMessageTypes.respond_children in rl_1["rate_limits_other"]
443-
assert ProtocolMessageTypes.respond_children not in rl_1["rate_limits_tx"]
444-
assert ProtocolMessageTypes.respond_children not in rl_2["rate_limits_other"]
445-
assert ProtocolMessageTypes.respond_children in rl_2["rate_limits_tx"]
446-
447-
assert ProtocolMessageTypes.request_block in rl_1["rate_limits_other"]
448-
assert ProtocolMessageTypes.request_block not in rl_1["rate_limits_tx"]
449-
assert ProtocolMessageTypes.request_block not in rl_2["rate_limits_other"]
450-
assert ProtocolMessageTypes.request_block not in rl_2["rate_limits_tx"]
451-
452-
comps = compose_rate_limits(rl_1, rl_2)
453-
# v2 limits are used if present
454-
assert ProtocolMessageTypes.respond_children not in comps["rate_limits_other"]
455-
assert ProtocolMessageTypes.respond_children in comps["rate_limits_tx"]
456-
457-
# Otherwise, fall back to v1
458-
assert ProtocolMessageTypes.request_block in rl_1["rate_limits_other"]
459-
assert ProtocolMessageTypes.request_block not in rl_1["rate_limits_tx"]
460-
461-
462415
@pytest.mark.anyio
463416
@pytest.mark.parametrize(
464417
"msg_type, size",

0 commit comments

Comments
 (0)