Skip to content

Commit 17886b6

Browse files
committed
mock timer in rate limit test to make them deterministic
1 parent 6f600af commit 17886b6

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

chia/_tests/core/server/test_rate_limits.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import asyncio
4-
53
import pytest
64
from chia_rs.sized_ints import uint32
75

@@ -25,6 +23,28 @@
2523
test_different_versions_results: list[int] = []
2624

2725

26+
def sleep() -> float:
27+
return 0.0
28+
29+
30+
@pytest.fixture
31+
def mock_timer(monkeypatch):
32+
current_time = 1000.0
33+
34+
def mock_monotonic(*args, **kwargs) -> float:
35+
nonlocal current_time
36+
return current_time
37+
38+
def mock_sleep(*args, **kwargs) -> None:
39+
nonlocal current_time
40+
current_time += int(args[0])
41+
42+
import chia
43+
44+
monkeypatch.setattr(chia.server.rate_limits, "get_time", mock_monotonic)
45+
monkeypatch.setattr(chia._tests.core.server.test_rate_limits, "sleep", mock_sleep)
46+
47+
2848
class TestRateLimits:
2949
@pytest.mark.anyio
3050
async def test_get_rate_limits_to_use(self):
@@ -33,7 +53,7 @@ async def test_get_rate_limits_to_use(self):
3353
assert get_rate_limits_to_use(rl_v1, rl_v1) == get_rate_limits_to_use(rl_v1, rl_v2)
3454

3555
@pytest.mark.anyio
36-
async def test_too_many_messages(self):
56+
async def test_too_many_messages(self, mock_timer):
3757
# Too many messages
3858
r = RateLimiter(incoming=True)
3959
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
@@ -61,7 +81,7 @@ async def test_too_many_messages(self):
6181
assert saw_disconnect
6282

6383
@pytest.mark.anyio
64-
async def test_large_message(self):
84+
async def test_large_message(self, mock_timer):
6585
# Large tx
6686
small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
6787
large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))
@@ -81,7 +101,7 @@ async def test_large_message(self):
81101
assert r.process_msg_and_check(large_blocks_message, rl_v2, rl_v2) is not None
82102

83103
@pytest.mark.anyio
84-
async def test_too_much_data(self):
104+
async def test_too_much_data(self, mock_timer):
85105
# Too much data
86106
r = RateLimiter(incoming=True)
87107
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
@@ -108,7 +128,7 @@ async def test_too_much_data(self):
108128
assert saw_disconnect
109129

110130
@pytest.mark.anyio
111-
async def test_non_tx_aggregate_limits(self):
131+
async def test_non_tx_aggregate_limits(self, mock_timer):
112132
# Frequency limits
113133
r = RateLimiter(incoming=True)
114134
message_1 = make_msg(ProtocolMessageTypes.coin_state_update, bytes([1] * 32))
@@ -144,7 +164,7 @@ async def test_non_tx_aggregate_limits(self):
144164
assert saw_disconnect
145165

146166
@pytest.mark.anyio
147-
async def test_periodic_reset(self):
167+
async def test_periodic_reset(self, mock_timer):
148168
r = RateLimiter(True, 5)
149169
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
150170
for i in range(10):
@@ -157,7 +177,7 @@ async def test_periodic_reset(self):
157177
saw_disconnect = True
158178
assert saw_disconnect
159179
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is not None
160-
await asyncio.sleep(6)
180+
sleep(6)
161181
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None
162182

163183
# Counts reset also
@@ -172,11 +192,11 @@ async def test_periodic_reset(self):
172192
if response is not None:
173193
saw_disconnect = True
174194
assert saw_disconnect
175-
await asyncio.sleep(6)
195+
sleep(6)
176196
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None
177197

178198
@pytest.mark.anyio
179-
async def test_percentage_limits(self):
199+
async def test_percentage_limits(self, mock_timer):
180200
r = RateLimiter(True, 60, 40)
181201
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
182202
for i in range(50):
@@ -235,7 +255,7 @@ async def test_percentage_limits(self):
235255
assert saw_disconnect
236256

237257
@pytest.mark.anyio
238-
async def test_too_many_outgoing_messages(self):
258+
async def test_too_many_outgoing_messages(self, mock_timer):
239259
# Too many messages
240260
r = RateLimiter(incoming=False)
241261
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
@@ -258,7 +278,7 @@ async def test_too_many_outgoing_messages(self):
258278
assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is None
259279

260280
@pytest.mark.anyio
261-
async def test_too_many_incoming_messages(self):
281+
async def test_too_many_incoming_messages(self, mock_timer):
262282
# Too many messages
263283
r = RateLimiter(incoming=True)
264284
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
@@ -353,7 +373,7 @@ async def test_different_versions(self, node_with_params, node_with_params_b, se
353373
assert len(set(test_different_versions_results)) >= 2
354374

355375
@pytest.mark.anyio
356-
async def test_compose(self):
376+
async def test_compose(self, mock_timer):
357377
rl_1 = rl_numbers[1]
358378
rl_2 = rl_numbers[2]
359379
assert ProtocolMessageTypes.respond_children in rl_1["rate_limits_other"]

chia/server/rate_limits.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
log = logging.getLogger(__name__)
1515

1616

17+
def get_time() -> float:
18+
return monotonic()
19+
20+
1721
# TODO: only full node disconnects based on rate limits
1822
class RateLimiter:
1923
incoming: bool
@@ -35,7 +39,7 @@ def __init__(self, incoming: bool, reset_seconds: int = 60, percentage_of_limit:
3539
"""
3640
self.incoming = incoming
3741
self.reset_seconds = reset_seconds
38-
self.current_slot = int(monotonic() // reset_seconds)
42+
self.current_slot = int(get_time() // reset_seconds)
3943
self.message_counts = Counter()
4044
self.message_cumulative_sizes = Counter()
4145
self.percentage_of_limit = percentage_of_limit
@@ -51,7 +55,7 @@ def process_msg_and_check(
5155
hit and the message is good to be sent or received.
5256
"""
5357

54-
current_slot = int(monotonic() // self.reset_seconds)
58+
current_slot = int(get_time() // self.reset_seconds)
5559
if current_slot != self.current_slot:
5660
self.current_slot = current_slot
5761
self.message_counts = Counter()

0 commit comments

Comments
 (0)