Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 18 additions & 66 deletions chia/_tests/core/server/test_rate_limits.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, cast
from typing import Union

import pytest
from chia_rs.sized_ints import uint32
Expand All @@ -13,8 +13,7 @@
from chia.protocols.outbound_message import make_msg
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import Capability
from chia.server.rate_limit_numbers import RLSettings, compose_rate_limits, get_rate_limits_to_use
from chia.server.rate_limit_numbers import rate_limits as rl_numbers
from chia.server.rate_limit_numbers import RLSettings, Unlimited, get_rate_limits_to_use
from chia.server.rate_limits import RateLimiter
from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection
Expand Down Expand Up @@ -66,39 +65,22 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
message_data = b"\0" * 1024
msg_type = ProtocolMessageTypes.new_transaction

limits: dict[str, Any] = {}
limits: dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]]

if limit_size:
limits.update(
{
# this is the rate limit across all (non-tx) messages
"non_tx_freq": count * 2,
# this is the byte size limit across all (non-tx) messages
"non_tx_max_total_size": count * len(message_data),
}
)
agg_limit = RLSettings(False, count * 2, 1024, count * len(message_data))
else:
limits.update(
{
# this is the rate limit across all (non-tx) messages
"non_tx_freq": count,
# this is the byte size limit across all (non-tx) messages
"non_tx_max_total_size": count * 2 * len(message_data),
}
)
agg_limit = RLSettings(False, count, 1024, count * 2 * len(message_data))

if limit_size:
rate_limit = {msg_type: RLSettings(count * 2, 1024, count * len(message_data))}
limits = {msg_type: RLSettings(not tx_msg, count * 2, 1024, count * len(message_data))}
else:
rate_limit = {msg_type: RLSettings(count, 1024, count * 2 * len(message_data))}
limits = {msg_type: RLSettings(not tx_msg, count, 1024, count * 2 * len(message_data))}

if tx_msg:
limits.update({"rate_limits_tx": rate_limit, "rate_limits_other": {}})
else:
limits.update({"rate_limits_other": rate_limit, "rate_limits_tx": {}})

def mock_get_limits(our_capabilities: list[Capability], peer_capabilities: list[Capability]) -> dict[str, Any]:
return limits
def mock_get_limits(
our_capabilities: list[Capability], peer_capabilities: list[Capability]
) -> tuple[dict[ProtocolMessageTypes, Union[RLSettings, Unlimited]], RLSettings]:
return limits, agg_limit

import chia.server.rate_limits

Expand Down Expand Up @@ -326,18 +308,18 @@ async def test_too_many_outgoing_messages() -> None:
# Too many messages
r = RateLimiter(incoming=False, get_time=lambda: 0)
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)

passed = 0
blocked = 0
for i in range(non_tx_freq):
for i in range(agg_limit.frequency):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
passed += 1
else:
blocked += 1

assert passed == 10
assert blocked == non_tx_freq - passed
assert blocked == agg_limit.frequency - passed

# ensure that *another* message type is not blocked because of this

Expand All @@ -350,18 +332,18 @@ async def test_too_many_incoming_messages() -> None:
# Too many messages
r = RateLimiter(incoming=True, get_time=lambda: 0)
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
non_tx_freq = get_rate_limits_to_use(rl_v2, rl_v2)["non_tx_freq"]
_, agg_limit = get_rate_limits_to_use(rl_v2, rl_v2)

passed = 0
blocked = 0
for i in range(non_tx_freq):
for i in range(agg_limit.frequency):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
passed += 1
else:
blocked += 1

assert passed == 10
assert blocked == non_tx_freq - passed
assert blocked == agg_limit.frequency - passed

# ensure that other message types *are* blocked because of this

Expand Down Expand Up @@ -435,43 +417,13 @@ async def test_different_versions(
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
# will use V2.
total_tx_msg_count = len(
get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities)["rate_limits_tx"]
)
total_tx_msg_count = len(get_rate_limits_to_use(a_con.local_capabilities, a_con.peer_capabilities))

test_different_versions_results.append(total_tx_msg_count)
if len(test_different_versions_results) >= 4:
assert len(set(test_different_versions_results)) >= 2


@pytest.mark.anyio
async def test_compose() -> None:
rl_1 = rl_numbers[1]
rl_2 = rl_numbers[2]
rl_1_rate_limits_other = cast(dict[ProtocolMessageTypes, RLSettings], rl_1["rate_limits_other"])
rl_2_rate_limits_other = cast(dict[ProtocolMessageTypes, RLSettings], rl_2["rate_limits_other"])
rl_1_rate_limits_tx = cast(dict[ProtocolMessageTypes, RLSettings], rl_1["rate_limits_tx"])
rl_2_rate_limits_tx = cast(dict[ProtocolMessageTypes, RLSettings], rl_2["rate_limits_tx"])
assert ProtocolMessageTypes.respond_children in rl_1_rate_limits_other
assert ProtocolMessageTypes.respond_children not in rl_1_rate_limits_tx
assert ProtocolMessageTypes.respond_children not in rl_2_rate_limits_other
assert ProtocolMessageTypes.respond_children in rl_2_rate_limits_tx

assert ProtocolMessageTypes.request_block in rl_1_rate_limits_other
assert ProtocolMessageTypes.request_block not in rl_1_rate_limits_tx
assert ProtocolMessageTypes.request_block not in rl_2_rate_limits_other
assert ProtocolMessageTypes.request_block not in rl_2_rate_limits_tx

comps = compose_rate_limits(rl_1, rl_2)
# v2 limits are used if present
assert ProtocolMessageTypes.respond_children not in comps["rate_limits_other"]
assert ProtocolMessageTypes.respond_children in comps["rate_limits_tx"]

# Otherwise, fall back to v1
assert ProtocolMessageTypes.request_block in rl_1_rate_limits_other
assert ProtocolMessageTypes.request_block not in rl_1_rate_limits_tx


@pytest.mark.anyio
@pytest.mark.parametrize(
"msg_type, size",
Expand Down
Loading
Loading