11from __future__ import annotations
22
33from dataclasses import dataclass
4- from typing import Any , cast
4+ from typing import Union
55
66import pytest
77from chia_rs .sized_ints import uint32
1313from chia .protocols .outbound_message import make_msg
1414from chia .protocols .protocol_message_types import ProtocolMessageTypes
1515from chia .protocols .shared_protocol import Capability
16- from chia .server .rate_limit_numbers import RLSettings , compose_rate_limits , get_rate_limits_to_use
17- from chia .server .rate_limit_numbers import rate_limits as rl_numbers
16+ from chia .server .rate_limit_numbers import RLSettings , Unlimited , get_rate_limits_to_use
1817from chia .server .rate_limits import RateLimiter
1918from chia .server .server import ChiaServer
2019from chia .server .ws_connection import WSChiaConnection
@@ -67,39 +66,22 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
6766 message_data = b"\0 " * 1024
6867 msg_type = ProtocolMessageTypes .new_transaction
6968
70- limits : dict [str , Any ] = {}
69+ limits : dict [ProtocolMessageTypes , Union [ RLSettings , Unlimited ]]
7170
7271 if limit_size :
73- limits .update (
74- {
75- # this is the rate limit across all (non-tx) messages
76- "non_tx_freq" : count * 2 ,
77- # this is the byte size limit across all (non-tx) messages
78- "non_tx_max_total_size" : count * len (message_data ),
79- }
80- )
72+ agg_limit = RLSettings (False , count * 2 , len (message_data ), count * len (message_data ))
8173 else :
82- limits .update (
83- {
84- # this is the rate limit across all (non-tx) messages
85- "non_tx_freq" : count ,
86- # this is the byte size limit across all (non-tx) messages
87- "non_tx_max_total_size" : count * 2 * len (message_data ),
88- }
89- )
74+ agg_limit = RLSettings (False , count , len (message_data ), count * 2 * len (message_data ))
9075
9176 if limit_size :
92- rate_limit = {msg_type : RLSettings (count * 2 , 1024 , count * len (message_data ))}
77+ limits = {msg_type : RLSettings (not tx_msg , count * 2 , len ( message_data ) , count * len (message_data ))}
9378 else :
94- rate_limit = {msg_type : RLSettings (count , 1024 , count * 2 * len (message_data ))}
79+ limits = {msg_type : RLSettings (not tx_msg , count , len ( message_data ) , count * 2 * len (message_data ))}
9580
96- if tx_msg :
97- limits .update ({"rate_limits_tx" : rate_limit , "rate_limits_other" : {}})
98- else :
99- limits .update ({"rate_limits_other" : rate_limit , "rate_limits_tx" : {}})
100-
101- def mock_get_limits (our_capabilities : list [Capability ], peer_capabilities : list [Capability ]) -> dict [str , Any ]:
102- return limits
81+ def mock_get_limits (
82+ our_capabilities : list [Capability ], peer_capabilities : list [Capability ]
83+ ) -> tuple [dict [ProtocolMessageTypes , Union [RLSettings , Unlimited ]], RLSettings ]:
84+ return limits , agg_limit
10385
10486 import chia .server .rate_limits
10587
@@ -327,18 +309,18 @@ async def test_too_many_outgoing_messages() -> None:
327309 # Too many messages
328310 r = RateLimiter (incoming = False , get_time = lambda : 0 )
329311 new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
330- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
312+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
331313
332314 passed = 0
333315 blocked = 0
334- for i in range (non_tx_freq ):
316+ for i in range (agg_limit . frequency ):
335317 if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
336318 passed += 1
337319 else :
338320 blocked += 1
339321
340322 assert passed == 10
341- assert blocked == non_tx_freq - passed
323+ assert blocked == agg_limit . frequency - passed
342324
343325 # ensure that *another* message type is not blocked because of this
344326
@@ -351,18 +333,18 @@ async def test_too_many_incoming_messages() -> None:
351333 # Too many messages
352334 r = RateLimiter (incoming = True , get_time = lambda : 0 )
353335 new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
354- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
336+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
355337
356338 passed = 0
357339 blocked = 0
358- for i in range (non_tx_freq ):
340+ for i in range (agg_limit . frequency ):
359341 if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
360342 passed += 1
361343 else :
362344 blocked += 1
363345
364346 assert passed == 10
365- assert blocked == non_tx_freq - passed
347+ assert blocked == agg_limit . frequency - passed
366348
367349 # ensure that other message types *are* blocked because of this
368350
@@ -436,43 +418,13 @@ async def test_different_versions(
436418 # The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
437419 # which would mean the same rate limits are always used. This should not happen, since two nodes with V2
438420 # will use V2.
439- total_tx_msg_count = len (
440- get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities )["rate_limits_tx" ]
441- )
421+ total_tx_msg_count = len (get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities ))
442422
443423 test_different_versions_results .append (total_tx_msg_count )
444424 if len (test_different_versions_results ) >= 4 :
445425 assert len (set (test_different_versions_results )) >= 2
446426
447427
448- @pytest .mark .anyio
449- async def test_compose () -> None :
450- rl_1 = rl_numbers [1 ]
451- rl_2 = rl_numbers [2 ]
452- rl_1_rate_limits_other = cast (dict [ProtocolMessageTypes , RLSettings ], rl_1 ["rate_limits_other" ])
453- rl_2_rate_limits_other = cast (dict [ProtocolMessageTypes , RLSettings ], rl_2 ["rate_limits_other" ])
454- rl_1_rate_limits_tx = cast (dict [ProtocolMessageTypes , RLSettings ], rl_1 ["rate_limits_tx" ])
455- rl_2_rate_limits_tx = cast (dict [ProtocolMessageTypes , RLSettings ], rl_2 ["rate_limits_tx" ])
456- assert ProtocolMessageTypes .respond_children in rl_1_rate_limits_other
457- assert ProtocolMessageTypes .respond_children not in rl_1_rate_limits_tx
458- assert ProtocolMessageTypes .respond_children not in rl_2_rate_limits_other
459- assert ProtocolMessageTypes .respond_children in rl_2_rate_limits_tx
460-
461- assert ProtocolMessageTypes .request_block in rl_1_rate_limits_other
462- assert ProtocolMessageTypes .request_block not in rl_1_rate_limits_tx
463- assert ProtocolMessageTypes .request_block not in rl_2_rate_limits_other
464- assert ProtocolMessageTypes .request_block not in rl_2_rate_limits_tx
465-
466- comps = compose_rate_limits (rl_1 , rl_2 )
467- # v2 limits are used if present
468- assert ProtocolMessageTypes .respond_children not in comps ["rate_limits_other" ]
469- assert ProtocolMessageTypes .respond_children in comps ["rate_limits_tx" ]
470-
471- # Otherwise, fall back to v1
472- assert ProtocolMessageTypes .request_block in rl_1_rate_limits_other
473- assert ProtocolMessageTypes .request_block not in rl_1_rate_limits_tx
474-
475-
476428@pytest .mark .anyio
477429@pytest .mark .parametrize (
478430 "msg_type, size" ,
0 commit comments