1
1
from __future__ import annotations
2
2
3
3
from dataclasses import dataclass
4
- from typing import Any , cast
4
+ from typing import Union
5
5
6
6
import pytest
7
7
from chia_rs .sized_ints import uint32
13
13
from chia .protocols .outbound_message import make_msg
14
14
from chia .protocols .protocol_message_types import ProtocolMessageTypes
15
15
from chia .protocols .shared_protocol import Capability
16
- from chia .server .rate_limit_numbers import RLSettings , compose_rate_limits , get_rate_limits_to_use
17
- from chia .server .rate_limit_numbers import rate_limits as rl_numbers
16
+ from chia .server .rate_limit_numbers import RLSettings , Unlimited , get_rate_limits_to_use
18
17
from chia .server .rate_limits import RateLimiter
19
18
from chia .server .server import ChiaServer
20
19
from chia .server .ws_connection import WSChiaConnection
@@ -67,39 +66,22 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
67
66
message_data = b"\0 " * 1024
68
67
msg_type = ProtocolMessageTypes .new_transaction
69
68
70
- limits : dict [str , Any ] = {}
69
+ limits : dict [ProtocolMessageTypes , Union [ RLSettings , Unlimited ]]
71
70
72
71
if limit_size :
73
- limits .update (
74
- {
75
- # this is the rate limit across all (non-tx) messages
76
- "non_tx_freq" : count * 2 ,
77
- # this is the byte size limit across all (non-tx) messages
78
- "non_tx_max_total_size" : count * len (message_data ),
79
- }
80
- )
72
+ agg_limit = RLSettings (False , count * 2 , len (message_data ), count * len (message_data ))
81
73
else :
82
- limits .update (
83
- {
84
- # this is the rate limit across all (non-tx) messages
85
- "non_tx_freq" : count ,
86
- # this is the byte size limit across all (non-tx) messages
87
- "non_tx_max_total_size" : count * 2 * len (message_data ),
88
- }
89
- )
74
+ agg_limit = RLSettings (False , count , len (message_data ), count * 2 * len (message_data ))
90
75
91
76
if limit_size :
92
- rate_limit = {msg_type : RLSettings (count * 2 , 1024 , count * len (message_data ))}
77
+ limits = {msg_type : RLSettings (not tx_msg , count * 2 , len ( message_data ) , count * len (message_data ))}
93
78
else :
94
- rate_limit = {msg_type : RLSettings (count , 1024 , count * 2 * len (message_data ))}
79
+ limits = {msg_type : RLSettings (not tx_msg , count , len ( message_data ) , count * 2 * len (message_data ))}
95
80
96
- if tx_msg :
97
- limits .update ({"rate_limits_tx" : rate_limit , "rate_limits_other" : {}})
98
- else :
99
- limits .update ({"rate_limits_other" : rate_limit , "rate_limits_tx" : {}})
100
-
101
- def mock_get_limits (our_capabilities : list [Capability ], peer_capabilities : list [Capability ]) -> dict [str , Any ]:
102
- return limits
81
+ def mock_get_limits (
82
+ our_capabilities : list [Capability ], peer_capabilities : list [Capability ]
83
+ ) -> tuple [dict [ProtocolMessageTypes , Union [RLSettings , Unlimited ]], RLSettings ]:
84
+ return limits , agg_limit
103
85
104
86
import chia .server .rate_limits
105
87
@@ -327,18 +309,18 @@ async def test_too_many_outgoing_messages() -> None:
327
309
# Too many messages
328
310
r = RateLimiter (incoming = False , get_time = lambda : 0 )
329
311
new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
330
- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
312
+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
331
313
332
314
passed = 0
333
315
blocked = 0
334
- for i in range (non_tx_freq ):
316
+ for i in range (agg_limit . frequency ):
335
317
if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
336
318
passed += 1
337
319
else :
338
320
blocked += 1
339
321
340
322
assert passed == 10
341
- assert blocked == non_tx_freq - passed
323
+ assert blocked == agg_limit . frequency - passed
342
324
343
325
# ensure that *another* message type is not blocked because of this
344
326
@@ -351,18 +333,18 @@ async def test_too_many_incoming_messages() -> None:
351
333
# Too many messages
352
334
r = RateLimiter (incoming = True , get_time = lambda : 0 )
353
335
new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
354
- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
336
+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
355
337
356
338
passed = 0
357
339
blocked = 0
358
- for i in range (non_tx_freq ):
340
+ for i in range (agg_limit . frequency ):
359
341
if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
360
342
passed += 1
361
343
else :
362
344
blocked += 1
363
345
364
346
assert passed == 10
365
- assert blocked == non_tx_freq - passed
347
+ assert blocked == agg_limit . frequency - passed
366
348
367
349
# ensure that other message types *are* blocked because of this
368
350
@@ -436,43 +418,13 @@ async def test_different_versions(
436
418
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
437
419
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
438
420
# will use V2.
439
- total_tx_msg_count = len (
440
- get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities )["rate_limits_tx" ]
441
- )
421
+ total_tx_msg_count = len (get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities ))
442
422
443
423
test_different_versions_results .append (total_tx_msg_count )
444
424
if len (test_different_versions_results ) >= 4 :
445
425
assert len (set (test_different_versions_results )) >= 2
446
426
447
427
448
- @pytest .mark .anyio
449
- async def test_compose () -> None :
450
- rl_1 = rl_numbers [1 ]
451
- rl_2 = rl_numbers [2 ]
452
- rl_1_rate_limits_other = cast (dict [ProtocolMessageTypes , RLSettings ], rl_1 ["rate_limits_other" ])
453
- rl_2_rate_limits_other = cast (dict [ProtocolMessageTypes , RLSettings ], rl_2 ["rate_limits_other" ])
454
- rl_1_rate_limits_tx = cast (dict [ProtocolMessageTypes , RLSettings ], rl_1 ["rate_limits_tx" ])
455
- rl_2_rate_limits_tx = cast (dict [ProtocolMessageTypes , RLSettings ], rl_2 ["rate_limits_tx" ])
456
- assert ProtocolMessageTypes .respond_children in rl_1_rate_limits_other
457
- assert ProtocolMessageTypes .respond_children not in rl_1_rate_limits_tx
458
- assert ProtocolMessageTypes .respond_children not in rl_2_rate_limits_other
459
- assert ProtocolMessageTypes .respond_children in rl_2_rate_limits_tx
460
-
461
- assert ProtocolMessageTypes .request_block in rl_1_rate_limits_other
462
- assert ProtocolMessageTypes .request_block not in rl_1_rate_limits_tx
463
- assert ProtocolMessageTypes .request_block not in rl_2_rate_limits_other
464
- assert ProtocolMessageTypes .request_block not in rl_2_rate_limits_tx
465
-
466
- comps = compose_rate_limits (rl_1 , rl_2 )
467
- # v2 limits are used if present
468
- assert ProtocolMessageTypes .respond_children not in comps ["rate_limits_other" ]
469
- assert ProtocolMessageTypes .respond_children in comps ["rate_limits_tx" ]
470
-
471
- # Otherwise, fall back to v1
472
- assert ProtocolMessageTypes .request_block in rl_1_rate_limits_other
473
- assert ProtocolMessageTypes .request_block not in rl_1_rate_limits_tx
474
-
475
-
476
428
@pytest .mark .anyio
477
429
@pytest .mark .parametrize (
478
430
"msg_type, size" ,
0 commit comments