1
1
from __future__ import annotations
2
2
3
3
from dataclasses import dataclass
4
- from typing import Any
4
+ from typing import Union
5
5
6
6
import pytest
7
7
from chia_rs .sized_ints import uint32
13
13
from chia .protocols .outbound_message import make_msg
14
14
from chia .protocols .protocol_message_types import ProtocolMessageTypes
15
15
from chia .protocols .shared_protocol import Capability
16
- from chia .server .rate_limit_numbers import RLSettings , compose_rate_limits , get_rate_limits_to_use
17
- from chia .server .rate_limit_numbers import rate_limits as rl_numbers
16
+ from chia .server .rate_limit_numbers import RLSettings , Unlimited , get_rate_limits_to_use
18
17
from chia .server .rate_limits import RateLimiter
19
18
from chia .server .server import ChiaServer
20
19
from chia .server .ws_connection import WSChiaConnection
@@ -65,39 +64,20 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
65
64
message_data = b"\0 " * 1024
66
65
msg_type = ProtocolMessageTypes .new_transaction
67
66
68
- limits : dict [str , Any ] = {}
67
+ limits : dict [ProtocolMessageTypes , Union [ RLSettings , Unlimited ]]
69
68
70
69
if limit_size :
71
- limits .update (
72
- {
73
- # this is the rate limit across all (non-tx) messages
74
- "non_tx_freq" : count * 2 ,
75
- # this is the byte size limit across all (non-tx) messages
76
- "non_tx_max_total_size" : count * len (message_data ),
77
- }
78
- )
70
+ agg_limit = RLSettings (False , count * 2 , 1024 , count * len (message_data ))
79
71
else :
80
- limits .update (
81
- {
82
- # this is the rate limit across all (non-tx) messages
83
- "non_tx_freq" : count ,
84
- # this is the byte size limit across all (non-tx) messages
85
- "non_tx_max_total_size" : count * 2 * len (message_data ),
86
- }
87
- )
72
+ agg_limit = RLSettings (False , count , 1024 , count * 2 * len (message_data ))
88
73
89
74
if limit_size :
90
- rate_limit = {msg_type : RLSettings (count * 2 , 1024 , count * len (message_data ))}
75
+ limits = {msg_type : RLSettings (not tx_msg , count * 2 , 1024 , count * len (message_data ))}
91
76
else :
92
- rate_limit = {msg_type : RLSettings (count , 1024 , count * 2 * len (message_data ))}
77
+ limits = {msg_type : RLSettings (not tx_msg , count , 1024 , count * 2 * len (message_data ))}
93
78
94
- if tx_msg :
95
- limits .update ({"rate_limits_tx" : rate_limit , "rate_limits_other" : {}})
96
- else :
97
- limits .update ({"rate_limits_other" : rate_limit , "rate_limits_tx" : {}})
98
-
99
- def mock_get_limits (* args , ** kwargs ) -> dict [str , Any ]:
100
- return limits
79
+ def mock_get_limits (* args , ** kwargs ) -> tuple [dict [ProtocolMessageTypes , Union [RLSettings , Unlimited ]], RLSettings ]:
80
+ return limits , agg_limit
101
81
102
82
import chia .server .rate_limits
103
83
@@ -325,18 +305,18 @@ async def test_too_many_outgoing_messages():
325
305
# Too many messages
326
306
r = RateLimiter (incoming = False , get_time = lambda : 0 )
327
307
new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
328
- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
308
+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
329
309
330
310
passed = 0
331
311
blocked = 0
332
- for i in range (non_tx_freq ):
312
+ for i in range (agg_limit . frequency ):
333
313
if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
334
314
passed += 1
335
315
else :
336
316
blocked += 1
337
317
338
318
assert passed == 10
339
- assert blocked == non_tx_freq - passed
319
+ assert blocked == agg_limit . frequency - passed
340
320
341
321
# ensure that *another* message type is not blocked because of this
342
322
@@ -349,18 +329,18 @@ async def test_too_many_incoming_messages():
349
329
# Too many messages
350
330
r = RateLimiter (incoming = True , get_time = lambda : 0 )
351
331
new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
352
- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
332
+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
353
333
354
334
passed = 0
355
335
blocked = 0
356
- for i in range (non_tx_freq ):
336
+ for i in range (agg_limit . frequency ):
357
337
if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
358
338
passed += 1
359
339
else :
360
340
blocked += 1
361
341
362
342
assert passed == 10
363
- assert blocked == non_tx_freq - passed
343
+ assert blocked == agg_limit . frequency - passed
364
344
365
345
# ensure that other message types *are* blocked because of this
366
346
@@ -432,39 +412,13 @@ async def test_different_versions(node_with_params, node_with_params_b, self_hos
432
412
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
433
413
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
434
414
# will use V2.
435
- total_tx_msg_count = len (
436
- get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities )["rate_limits_tx" ]
437
- )
415
+ total_tx_msg_count = len (get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities ))
438
416
439
417
test_different_versions_results .append (total_tx_msg_count )
440
418
if len (test_different_versions_results ) >= 4 :
441
419
assert len (set (test_different_versions_results )) >= 2
442
420
443
421
444
- @pytest .mark .anyio
445
- async def test_compose ():
446
- rl_1 = rl_numbers [1 ]
447
- rl_2 = rl_numbers [2 ]
448
- assert ProtocolMessageTypes .respond_children in rl_1 ["rate_limits_other" ]
449
- assert ProtocolMessageTypes .respond_children not in rl_1 ["rate_limits_tx" ]
450
- assert ProtocolMessageTypes .respond_children not in rl_2 ["rate_limits_other" ]
451
- assert ProtocolMessageTypes .respond_children in rl_2 ["rate_limits_tx" ]
452
-
453
- assert ProtocolMessageTypes .request_block in rl_1 ["rate_limits_other" ]
454
- assert ProtocolMessageTypes .request_block not in rl_1 ["rate_limits_tx" ]
455
- assert ProtocolMessageTypes .request_block not in rl_2 ["rate_limits_other" ]
456
- assert ProtocolMessageTypes .request_block not in rl_2 ["rate_limits_tx" ]
457
-
458
- comps = compose_rate_limits (rl_1 , rl_2 )
459
- # v2 limits are used if present
460
- assert ProtocolMessageTypes .respond_children not in comps ["rate_limits_other" ]
461
- assert ProtocolMessageTypes .respond_children in comps ["rate_limits_tx" ]
462
-
463
- # Otherwise, fall back to v1
464
- assert ProtocolMessageTypes .request_block in rl_1 ["rate_limits_other" ]
465
- assert ProtocolMessageTypes .request_block not in rl_1 ["rate_limits_tx" ]
466
-
467
-
468
422
@pytest .mark .anyio
469
423
@pytest .mark .parametrize (
470
424
"msg_type, size" ,
0 commit comments