Skip to content

Commit 11560f5

Browse files
authored
TODO: throttle on async validators (#755)
* fixed todo: throttle on async validators * added test: validate message respects concurrency limit * added newsfragment * added configurable validator semaphore in the PubSub constructor * added the concurrency-checker in the original test-validate-msg test case * separate out a _run_async_validator function * remove redundant run_async_validator
1 parent 3507531 commit 11560f5

File tree

3 files changed

+77
-14
lines changed

3 files changed

+77
-14
lines changed

libp2p/pubsub/pubsub.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,17 @@ class TopicValidator(NamedTuple):
102102
is_async: bool
103103

104104

105+
MAX_CONCURRENT_VALIDATORS = 10
106+
107+
105108
class Pubsub(Service, IPubsub):
106109
host: IHost
107110

108111
router: IPubsubRouter
109112

110113
peer_receive_channel: trio.MemoryReceiveChannel[ID]
111114
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
115+
_validator_semaphore: trio.Semaphore
112116

113117
seen_messages: LastSeenCache
114118

@@ -143,6 +147,7 @@ def __init__(
143147
msg_id_constructor: Callable[
144148
[rpc_pb2.Message], bytes
145149
] = get_peer_and_seqno_msg_id,
150+
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
146151
) -> None:
147152
"""
148153
Construct a new Pubsub object, which is responsible for handling all
@@ -168,6 +173,7 @@ def __init__(
168173
# Therefore, we can only close from the receive side.
169174
self.peer_receive_channel = peer_receive
170175
self.dead_peer_receive_channel = dead_peer_receive
176+
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
171177
# Register a notifee
172178
self.host.get_network().register_notifee(
173179
PubsubNotifee(peer_send, dead_peer_send)
@@ -657,7 +663,11 @@ async def publish(self, topic_id: str | list[str], data: bytes) -> None:
657663

658664
logger.debug("successfully published message %s", msg)
659665

660-
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
666+
async def validate_msg(
667+
self,
668+
msg_forwarder: ID,
669+
msg: rpc_pb2.Message,
670+
) -> None:
661671
"""
662672
Validate the received message.
663673
@@ -680,23 +690,34 @@ async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
680690
if not validator(msg_forwarder, msg):
681691
raise ValidationError(f"Validation failed for msg={msg}")
682692

683-
# TODO: Implement throttle on async validators
684-
685693
if len(async_topic_validators) > 0:
686694
# Appends to lists are thread safe in CPython
687-
results = []
688-
689-
async def run_async_validator(func: AsyncValidatorFn) -> None:
690-
result = await func(msg_forwarder, msg)
691-
results.append(result)
695+
results: list[bool] = []
692696

693697
async with trio.open_nursery() as nursery:
694698
for async_validator in async_topic_validators:
695-
nursery.start_soon(run_async_validator, async_validator)
699+
nursery.start_soon(
700+
self._run_async_validator,
701+
async_validator,
702+
msg_forwarder,
703+
msg,
704+
results,
705+
)
696706

697707
if not all(results):
698708
raise ValidationError(f"Validation failed for msg={msg}")
699709

710+
async def _run_async_validator(
711+
self,
712+
func: AsyncValidatorFn,
713+
msg_forwarder: ID,
714+
msg: rpc_pb2.Message,
715+
results: list[bool],
716+
) -> None:
717+
async with self._validator_semaphore:
718+
result = await func(msg_forwarder, msg)
719+
results.append(result)
720+
700721
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
701722
"""
702723
Push a pubsub message to others.

newsfragments/755.performance.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added throttling for async topic validators in validate_msg, enforcing a
2+
concurrency limit to prevent resource exhaustion under heavy load.

tests/core/pubsub/test_pubsub.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from typing import (
66
NamedTuple,
77
)
8+
from unittest.mock import patch
89

910
import pytest
1011
import trio
1112

13+
from libp2p.custom_types import AsyncValidatorFn
1214
from libp2p.exceptions import (
1315
ValidationError,
1416
)
@@ -243,7 +245,37 @@ async def async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
243245
((False, True), (True, False), (True, True)),
244246
)
245247
@pytest.mark.trio
246-
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
248+
async def test_validate_msg_with_throttle_condition(
249+
is_topic_1_val_passed, is_topic_2_val_passed
250+
):
251+
CONCURRENCY_LIMIT = 10
252+
253+
state = {
254+
"concurrency_counter": 0,
255+
"max_observed": 0,
256+
}
257+
lock = trio.Lock()
258+
259+
async def mock_run_async_validator(
260+
self,
261+
func: AsyncValidatorFn,
262+
msg_forwarder: ID,
263+
msg: rpc_pb2.Message,
264+
results: list[bool],
265+
) -> None:
266+
async with self._validator_semaphore:
267+
async with lock:
268+
state["concurrency_counter"] += 1
269+
if state["concurrency_counter"] > state["max_observed"]:
270+
state["max_observed"] = state["concurrency_counter"]
271+
272+
try:
273+
result = await func(msg_forwarder, msg)
274+
results.append(result)
275+
finally:
276+
async with lock:
277+
state["concurrency_counter"] -= 1
278+
247279
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
248280

249281
def passed_sync_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
@@ -280,11 +312,19 @@ async def failed_async_validator(peer_id: ID, msg: rpc_pb2.Message) -> bool:
280312
seqno=b"\x00" * 8,
281313
)
282314

283-
if is_topic_1_val_passed and is_topic_2_val_passed:
284-
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
285-
else:
286-
with pytest.raises(ValidationError):
315+
with patch(
316+
"libp2p.pubsub.pubsub.Pubsub._run_async_validator",
317+
new=mock_run_async_validator,
318+
):
319+
if is_topic_1_val_passed and is_topic_2_val_passed:
287320
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
321+
else:
322+
with pytest.raises(ValidationError):
323+
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
324+
325+
assert state["max_observed"] <= CONCURRENCY_LIMIT, (
326+
f"Max concurrency observed: {state['max_observed']}"
327+
)
288328

289329

290330
@pytest.mark.trio

0 commit comments

Comments
 (0)