1
1
from __future__ import annotations
2
2
3
3
from dataclasses import dataclass
4
- from typing import Any , cast
4
+ from typing import Union
5
5
6
6
import pytest
7
7
from chia_rs .sized_ints import uint32
13
13
from chia .protocols .outbound_message import make_msg
14
14
from chia .protocols .protocol_message_types import ProtocolMessageTypes
15
15
from chia .protocols .shared_protocol import Capability
16
- from chia .server .rate_limit_numbers import RLSettings , compose_rate_limits , get_rate_limits_to_use
17
- from chia .server .rate_limit_numbers import rate_limits as rl_numbers
16
+ from chia .server .rate_limit_numbers import RLSettings , Unlimited , get_rate_limits_to_use
18
17
from chia .server .rate_limits import RateLimiter
19
18
from chia .server .server import ChiaServer
20
19
from chia .server .ws_connection import WSChiaConnection
@@ -66,39 +65,22 @@ async def test_limits_v2(incoming: bool, tx_msg: bool, limit_size: bool, monkeyp
66
65
message_data = b"\0 " * 1024
67
66
msg_type = ProtocolMessageTypes .new_transaction
68
67
69
- limits : dict [str , Any ] = {}
68
+ limits : dict [ProtocolMessageTypes , Union [ RLSettings , Unlimited ]]
70
69
71
70
if limit_size :
72
- limits .update (
73
- {
74
- # this is the rate limit across all (non-tx) messages
75
- "non_tx_freq" : count * 2 ,
76
- # this is the byte size limit across all (non-tx) messages
77
- "non_tx_max_total_size" : count * len (message_data ),
78
- }
79
- )
71
+ agg_limit = RLSettings (False , count * 2 , 1024 , count * len (message_data ))
80
72
else :
81
- limits .update (
82
- {
83
- # this is the rate limit across all (non-tx) messages
84
- "non_tx_freq" : count ,
85
- # this is the byte size limit across all (non-tx) messages
86
- "non_tx_max_total_size" : count * 2 * len (message_data ),
87
- }
88
- )
73
+ agg_limit = RLSettings (False , count , 1024 , count * 2 * len (message_data ))
89
74
90
75
if limit_size :
91
- rate_limit = {msg_type : RLSettings (count * 2 , 1024 , count * len (message_data ))}
76
+ limits = {msg_type : RLSettings (not tx_msg , count * 2 , 1024 , count * len (message_data ))}
92
77
else :
93
- rate_limit = {msg_type : RLSettings (count , 1024 , count * 2 * len (message_data ))}
78
+ limits = {msg_type : RLSettings (not tx_msg , count , 1024 , count * 2 * len (message_data ))}
94
79
95
- if tx_msg :
96
- limits .update ({"rate_limits_tx" : rate_limit , "rate_limits_other" : {}})
97
- else :
98
- limits .update ({"rate_limits_other" : rate_limit , "rate_limits_tx" : {}})
99
-
100
- def mock_get_limits (our_capabilities : list [Capability ], peer_capabilities : list [Capability ]) -> dict [str , Any ]:
101
- return limits
80
+ def mock_get_limits (
81
+ our_capabilities : list [Capability ], peer_capabilities : list [Capability ]
82
+ ) -> tuple [dict [ProtocolMessageTypes , Union [RLSettings , Unlimited ]], RLSettings ]:
83
+ return limits , agg_limit
102
84
103
85
import chia .server .rate_limits
104
86
@@ -326,18 +308,18 @@ async def test_too_many_outgoing_messages() -> None:
326
308
# Too many messages
327
309
r = RateLimiter (incoming = False , get_time = lambda : 0 )
328
310
new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
329
- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
311
+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
330
312
331
313
passed = 0
332
314
blocked = 0
333
- for i in range (non_tx_freq ):
315
+ for i in range (agg_limit . frequency ):
334
316
if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
335
317
passed += 1
336
318
else :
337
319
blocked += 1
338
320
339
321
assert passed == 10
340
- assert blocked == non_tx_freq - passed
322
+ assert blocked == agg_limit . frequency - passed
341
323
342
324
# ensure that *another* message type is not blocked because of this
343
325
@@ -350,18 +332,18 @@ async def test_too_many_incoming_messages() -> None:
350
332
# Too many messages
351
333
r = RateLimiter (incoming = True , get_time = lambda : 0 )
352
334
new_peers_message = make_msg (ProtocolMessageTypes .respond_peers , bytes ([1 ]))
353
- non_tx_freq = get_rate_limits_to_use (rl_v2 , rl_v2 )[ "non_tx_freq" ]
335
+ _ , agg_limit = get_rate_limits_to_use (rl_v2 , rl_v2 )
354
336
355
337
passed = 0
356
338
blocked = 0
357
- for i in range (non_tx_freq ):
339
+ for i in range (agg_limit . frequency ):
358
340
if r .process_msg_and_check (new_peers_message , rl_v2 , rl_v2 ) is None :
359
341
passed += 1
360
342
else :
361
343
blocked += 1
362
344
363
345
assert passed == 10
364
- assert blocked == non_tx_freq - passed
346
+ assert blocked == agg_limit . frequency - passed
365
347
366
348
# ensure that other message types *are* blocked because of this
367
349
@@ -435,43 +417,13 @@ async def test_different_versions(
435
417
# The following code checks whether all of the runs resulted in the same number of items in "rate_limits_tx",
436
418
# which would mean the same rate limits are always used. This should not happen, since two nodes with V2
437
419
# will use V2.
438
- total_tx_msg_count = len (
439
- get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities )["rate_limits_tx" ]
440
- )
420
+ total_tx_msg_count = len (get_rate_limits_to_use (a_con .local_capabilities , a_con .peer_capabilities ))
441
421
442
422
test_different_versions_results .append (total_tx_msg_count )
443
423
if len (test_different_versions_results ) >= 4 :
444
424
assert len (set (test_different_versions_results )) >= 2
445
425
446
426
447
- @pytest .mark .anyio
448
- async def test_compose () -> None :
449
- rl_1 = rl_numbers [1 ]
450
- rl_2 = rl_numbers [2 ]
451
- rl_1_rate_limits_other = cast (dict [ProtocolMessageTypes , RLSettings ], rl_1 ["rate_limits_other" ])
452
- rl_2_rate_limits_other = cast (dict [ProtocolMessageTypes , RLSettings ], rl_2 ["rate_limits_other" ])
453
- rl_1_rate_limits_tx = cast (dict [ProtocolMessageTypes , RLSettings ], rl_1 ["rate_limits_tx" ])
454
- rl_2_rate_limits_tx = cast (dict [ProtocolMessageTypes , RLSettings ], rl_2 ["rate_limits_tx" ])
455
- assert ProtocolMessageTypes .respond_children in rl_1_rate_limits_other
456
- assert ProtocolMessageTypes .respond_children not in rl_1_rate_limits_tx
457
- assert ProtocolMessageTypes .respond_children not in rl_2_rate_limits_other
458
- assert ProtocolMessageTypes .respond_children in rl_2_rate_limits_tx
459
-
460
- assert ProtocolMessageTypes .request_block in rl_1_rate_limits_other
461
- assert ProtocolMessageTypes .request_block not in rl_1_rate_limits_tx
462
- assert ProtocolMessageTypes .request_block not in rl_2_rate_limits_other
463
- assert ProtocolMessageTypes .request_block not in rl_2_rate_limits_tx
464
-
465
- comps = compose_rate_limits (rl_1 , rl_2 )
466
- # v2 limits are used if present
467
- assert ProtocolMessageTypes .respond_children not in comps ["rate_limits_other" ]
468
- assert ProtocolMessageTypes .respond_children in comps ["rate_limits_tx" ]
469
-
470
- # Otherwise, fall back to v1
471
- assert ProtocolMessageTypes .request_block in rl_1_rate_limits_other
472
- assert ProtocolMessageTypes .request_block not in rl_1_rate_limits_tx
473
-
474
-
475
427
@pytest .mark .anyio
476
428
@pytest .mark .parametrize (
477
429
"msg_type, size" ,
0 commit comments