Skip to content

Commit 3e88de7

Browse files
committed
simplify rate_limit_numbers by flattening the structure. just map message type -> rate limits
1 parent fc87630 commit 3e88de7

File tree

3 files changed

+187
-261
lines changed

3 files changed

+187
-261
lines changed

chia/_tests/core/server/test_rate_limits.py

Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import Any, cast
4+
from typing import Union
55

66
import pytest
77
from chia_rs.sized_ints import uint32
@@ -13,8 +13,7 @@
1313
from chia.protocols.outbound_message import make_msg
1414
from chia.protocols.protocol_message_types import ProtocolMessageTypes
1515
from chia.protocols.shared_protocol import Capability
16-
from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use
17-
from chia.server.rate_limit_numbers import rate_limits as rl_numbers
16+
from chia.server.rate_limit_numbers import RLSettings, Unlimited, get_rate_limits_to_use
1817
from chia.server.rate_limits import RateLimiter
1918
from chia.server.server import ChiaServer
2019
from chia.server.ws_connection import WSChiaConnection
@@ -66,39 +65,22 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
6665
message_data = b"\0" * 1024
6766
msg_type = ProtocolMessageTypes.new_transaction
6867

69-
limits: dict[str, Any] = {}
68+
limits: dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]]
7069

7170
if limit_size:
72-
limits.update(
73-
{
74-
# this is the rate limit across all (non-tx) messages
75-
"non_tx_freq": count * 2,
76-
# this is the byte size limit across all (non-tx) messages
77-
"non_tx_max_total_size": count * len(message_data),
78-
}
79-
)
71+
agg_limit = RLSettings(False, count * 2, 1024, count * len(message_data))
8072
else:
81-
limits.update(
82-
{
83-
# this is the rate limit across all (non-tx) messages
84-
"non_tx_freq": count,
85-
# this is the byte size limit across all (non-tx) messages
86-
"non_tx_max_total_size": count * 2 * len(message_data),
87-
}
88-
)
73+
agg_limit = RLSettings(False, count, 1024, count * 2 * len(message_data))
8974

9075
if limit_size:
91-
rate_limit = {msg_type: RLSettings(count * 2, 1024, count * len(message_data))}
76+
limits = {msg_type: RLSettings(not tx_msg, count * 2, 1024, count * len(message_data))}
9277
else:
93-
rate_limit = {msg_type: RLSettings(count, 1024, count * 2 * len(message_data))}
78+
limits = {msg_type: RLSettings(not tx_msg, count, 1024, count * 2 * len(message_data))}
9479

95-
if tx_msg:
96-
limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}})
97-
else:
98-
limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}})
99-
100-
def mock_get_limits(our_capabilities: list[Capability], peer_capabilities: list[Capability]) -> dict[str, Any]:
101-
return limits
80+
def mock_get_limits(
81+
our_capabilities: list[Capability], peer_capabilities: list[Capability]
82+
) -> tuple[dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]], RLSettings]:
83+
return limits, agg_limit
10284

10385
import chia.server.rate_limits
10486

@@ -326,18 +308,18 @@ async def test_too_many_outgoing_messages() -> None:
326308
# Too many messages
327309
r = RateLimiter(incoming=False, get_time=lambda: 0)
328310
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
329-
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
311+
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)
330312

331313
passed = 0
332314
blocked = 0
333-
for i in range(non_tx_freq):
315+
for i in range(agg_limit.frequency):
334316
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
335317
passed += 1
336318
else:
337319
blocked += 1
338320

339321
assert passed == 10
340-
assert blocked == non_tx_freq - passed
322+
assert blocked == agg_limit.frequency - passed
341323

342324
# ensure that *another* message type is not blocked because of this
343325

@@ -350,18 +332,18 @@ async def test_too_many_incoming_messages() -> None:
350332
# Too many messages
351333
r = RateLimiter(incoming=True, get_time=lambda: 0)
352334
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
353-
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
335+
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)
354336

355337
passed = 0
356338
blocked = 0
357-
for i in range(non_tx_freq):
339+
for i in range(agg_limit.frequency):
358340
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
359341
passed += 1
360342
else:
361343
blocked += 1
362344

363345
assert passed == 10
364-
assert blocked == non_tx_freq - passed
346+
assert blocked == agg_limit.frequency - passed
365347

366348
# ensure that other message types *are* blocked because of this
367349

@@ -435,43 +417,13 @@ async def test_different_versions(
435417
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
436418
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
437419
# will use V2.
438-
total_tx_msg_count = len(
439-
get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities)["rate_limits_tx"]
440-
)
420+
total_tx_msg_count = len(get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities))
441421

442422
test_different_versions_results.append(total_tx_msg_count)
443423
if len(test_different_versions_results) >= 4:
444424
assert len(set(test_different_versions_results)) >= 2
445425

446426

447-
@pytest.mark.anyio
448-
async def test_compose() -> None:
449-
rl_1 = rl_numbers[1]
450-
rl_2 = rl_numbers[2]
451-
rl_1_rate_limits_other = cast(dict[ProtocolMessageTypes, RLSettings], rl_1["rate_limits_other"])
452-
rl_2_rate_limits_other = cast(dict[ProtocolMessageTypes, RLSettings], rl_2["rate_limits_other"])
453-
rl_1_rate_limits_tx = cast(dict[ProtocolMessageTypes, RLSettings], rl_1["rate_limits_tx"])
454-
rl_2_rate_limits_tx = cast(dict[ProtocolMessageTypes, RLSettings], rl_2["rate_limits_tx"])
455-
assert ProtocolMessageTypes.respond_children in rl_1_rate_limits_other
456-
assert ProtocolMessageTypes.respond_children not in rl_1_rate_limits_tx
457-
assert ProtocolMessageTypes.respond_children not in rl_2_rate_limits_other
458-
assert ProtocolMessageTypes.respond_children in rl_2_rate_limits_tx
459-
460-
assert ProtocolMessageTypes.request_block in rl_1_rate_limits_other
461-
assert ProtocolMessageTypes.request_block not in rl_1_rate_limits_tx
462-
assert ProtocolMessageTypes.request_block not in rl_2_rate_limits_other
463-
assert ProtocolMessageTypes.request_block not in rl_2_rate_limits_tx
464-
465-
comps = compose_rate_limits(rl_1, rl_2)
466-
# v2 limits are used if present
467-
assert ProtocolMessageTypes.respond_children not in comps["rate_limits_other"]
468-
assert ProtocolMessageTypes.respond_children in comps["rate_limits_tx"]
469-
470-
# Otherwise, fall back to v1
471-
assert ProtocolMessageTypes.request_block in rl_1_rate_limits_other
472-
assert ProtocolMessageTypes.request_block not in rl_1_rate_limits_tx
473-
474-
475427
@pytest.mark.anyio
476428
@pytest.mark.parametrize(
477429
"msg_type, size",

0 commit comments

Comments
 (0)