Skip to content

Commit c1d15df

Browse files
authored
Simplify rate limit numbers (#19991)
* remove unused default_settings in rate limits * simplify rate_limit_numbers by flattening the structure. just map message type -> rate limits * review comments
1 parent 698b2dd commit c1d15df

File tree

4 files changed

+217
-267
lines changed

4 files changed

+217
-267
lines changed

chia/_tests/core/server/test_rate_limits.py

Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import Any, cast
4+
from typing import Union
55

66
import pytest
77
from chia_rs.sized_ints import uint32
@@ -13,8 +13,7 @@
1313
from chia.protocols.outbound_message import make_msg
1414
from chia.protocols.protocol_message_types import ProtocolMessageTypes
1515
from chia.protocols.shared_protocol import Capability
16-
from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use
17-
from chia.server.rate_limit_numbers import rate_limits as rl_numbers
16+
from chia.server.rate_limit_numbers import RLSettings, Unlimited, get_rate_limits_to_use
1817
from chia.server.rate_limits import RateLimiter
1918
from chia.server.server import ChiaServer
2019
from chia.server.ws_connection import WSChiaConnection
@@ -67,39 +66,22 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
6766
message_data = b"\0" * 1024
6867
msg_type = ProtocolMessageTypes.new_transaction
6968

70-
limits: dict[str, Any] = {}
69+
limits: dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]]
7170

7271
if limit_size:
73-
limits.update(
74-
{
75-
# this is the rate limit across all (non-tx) messages
76-
"non_tx_freq": count * 2,
77-
# this is the byte size limit across all (non-tx) messages
78-
"non_tx_max_total_size": count * len(message_data),
79-
}
80-
)
72+
agg_limit = RLSettings(False, count * 2, len(message_data), count * len(message_data))
8173
else:
82-
limits.update(
83-
{
84-
# this is the rate limit across all (non-tx) messages
85-
"non_tx_freq": count,
86-
# this is the byte size limit across all (non-tx) messages
87-
"non_tx_max_total_size": count * 2 * len(message_data),
88-
}
89-
)
74+
agg_limit = RLSettings(False, count, len(message_data), count * 2 * len(message_data))
9075

9176
if limit_size:
92-
rate_limit = {msg_type: RLSettings(count * 2, 1024, count * len(message_data))}
77+
limits = {msg_type: RLSettings(not tx_msg, count * 2, len(message_data), count * len(message_data))}
9378
else:
94-
rate_limit = {msg_type: RLSettings(count, 1024, count * 2 * len(message_data))}
79+
limits = {msg_type: RLSettings(not tx_msg, count, len(message_data), count * 2 * len(message_data))}
9580

96-
if tx_msg:
97-
limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}})
98-
else:
99-
limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}})
100-
101-
def mock_get_limits(our_capabilities: list[Capability], peer_capabilities: list[Capability]) -> dict[str, Any]:
102-
return limits
81+
def mock_get_limits(
82+
our_capabilities: list[Capability], peer_capabilities: list[Capability]
83+
) -> tuple[dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]], RLSettings]:
84+
return limits, agg_limit
10385

10486
import chia.server.rate_limits
10587

@@ -327,18 +309,18 @@ async def test_too_many_outgoing_messages() -> None:
327309
# Too many messages
328310
r = RateLimiter(incoming=False, get_time=lambda: 0)
329311
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
330-
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
312+
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)
331313

332314
passed = 0
333315
blocked = 0
334-
for i in range(non_tx_freq):
316+
for i in range(agg_limit.frequency):
335317
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
336318
passed += 1
337319
else:
338320
blocked += 1
339321

340322
assert passed == 10
341-
assert blocked == non_tx_freq - passed
323+
assert blocked == agg_limit.frequency - passed
342324

343325
# ensure that *another* message type is not blocked because of this
344326

@@ -351,18 +333,18 @@ async def test_too_many_incoming_messages() -> None:
351333
# Too many messages
352334
r = RateLimiter(incoming=True, get_time=lambda: 0)
353335
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
354-
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
336+
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)
355337

356338
passed = 0
357339
blocked = 0
358-
for i in range(non_tx_freq):
340+
for i in range(agg_limit.frequency):
359341
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
360342
passed += 1
361343
else:
362344
blocked += 1
363345

364346
assert passed == 10
365-
assert blocked == non_tx_freq - passed
347+
assert blocked == agg_limit.frequency - passed
366348

367349
# ensure that other message types *are* blocked because of this
368350

@@ -436,43 +418,13 @@ async def test_different_versions(
436418
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
437419
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
438420
# will use V2.
439-
total_tx_msg_count = len(
440-
get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities)["rate_limits_tx"]
441-
)
421+
total_tx_msg_count = len(get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities))
442422

443423
test_different_versions_results.append(total_tx_msg_count)
444424
if len(test_different_versions_results) >= 4:
445425
assert len(set(test_different_versions_results)) >= 2
446426

447427

448-
@pytest.mark.anyio
449-
async def test_compose() -> None:
450-
rl_1 = rl_numbers[1]
451-
rl_2 = rl_numbers[2]
452-
rl_1_rate_limits_other = cast(dict[ProtocolMessageTypes, RLSettings], rl_1["rate_limits_other"])
453-
rl_2_rate_limits_other = cast(dict[ProtocolMessageTypes, RLSettings], rl_2["rate_limits_other"])
454-
rl_1_rate_limits_tx = cast(dict[ProtocolMessageTypes, RLSettings], rl_1["rate_limits_tx"])
455-
rl_2_rate_limits_tx = cast(dict[ProtocolMessageTypes, RLSettings], rl_2["rate_limits_tx"])
456-
assert ProtocolMessageTypes.respond_children in rl_1_rate_limits_other
457-
assert ProtocolMessageTypes.respond_children not in rl_1_rate_limits_tx
458-
assert ProtocolMessageTypes.respond_children not in rl_2_rate_limits_other
459-
assert ProtocolMessageTypes.respond_children in rl_2_rate_limits_tx
460-
461-
assert ProtocolMessageTypes.request_block in rl_1_rate_limits_other
462-
assert ProtocolMessageTypes.request_block not in rl_1_rate_limits_tx
463-
assert ProtocolMessageTypes.request_block not in rl_2_rate_limits_other
464-
assert ProtocolMessageTypes.request_block not in rl_2_rate_limits_tx
465-
466-
comps = compose_rate_limits(rl_1, rl_2)
467-
# v2 limits are used if present
468-
assert ProtocolMessageTypes.respond_children not in comps["rate_limits_other"]
469-
assert ProtocolMessageTypes.respond_children in comps["rate_limits_tx"]
470-
471-
# Otherwise, fall back to v1
472-
assert ProtocolMessageTypes.request_block in rl_1_rate_limits_other
473-
assert ProtocolMessageTypes.request_block not in rl_1_rate_limits_tx
474-
475-
476428
@pytest.mark.anyio
477429
@pytest.mark.parametrize(
478430
"msg_type, size",

chia/_tests/util/test_network_protocol_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import inspect
55
from typing import Any, cast
66

7+
import pytest
8+
79
from chia.protocols import (
810
farmer_protocol,
911
full_node_protocol,
@@ -264,3 +266,19 @@ def test_missing_messages() -> None:
264266
assert types_in_module(shared_protocol) == shared_msgs, (
265267
f"message types were added or removed from shared_protocol. {STANDARD_ADVICE}"
266268
)
269+
270+
271+
@pytest.mark.parametrize("version", [1, 2])
272+
def test_rate_limits_complete(version: int) -> None:
273+
from chia.protocols.protocol_message_types import ProtocolMessageTypes
274+
from chia.server.rate_limit_numbers import rate_limits
275+
276+
if version == 1:
277+
composed = rate_limits[1]
278+
elif version == 2:
279+
composed = {
280+
**rate_limits[1],
281+
**rate_limits[2],
282+
}
283+
284+
assert set(composed.keys()) == set(ProtocolMessageTypes)

0 commit comments

Comments
 (0)