Skip to content

Commit 355227d

Browse files
committed
PeerSubscribers now specify what messages they are interested in
1 parent 888f105 commit 355227d

File tree

9 files changed

+186
-55
lines changed

9 files changed

+186
-55
lines changed

p2p/peer.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import collections
33
import contextlib
44
import datetime
5+
import functools
56
import logging
67
import operator
78
import random
@@ -20,6 +21,7 @@
2021
Dict,
2122
Iterator,
2223
List,
24+
Set,
2325
TYPE_CHECKING,
2426
Tuple,
2527
Type,
@@ -381,14 +383,23 @@ def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -
381383
raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))
382384

383385
def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
386+
cmd_type = type(cmd)
387+
384388
if self._subscribers:
385-
for subscriber in self._subscribers:
389+
was_added = tuple(
386390
subscriber.add_msg((self, cmd, msg))
391+
for subscriber
392+
in self._subscribers
393+
)
394+
if not any(was_added):
395+
self.logger.warn(
396+
"Peer %s has no subscribers for msg type %s",
397+
self,
398+
cmd_type.__name__,
399+
)
387400
else:
388401
self.logger.warn("Peer %s has no subscribers, discarding %s msg", self, cmd)
389402

390-
cmd_type = type(cmd)
391-
392403
if cmd_type in self.pending_requests:
393404
request, future = self.pending_requests[cmd_type]
394405
try:
@@ -547,10 +558,28 @@ def __hash__(self) -> int:
547558
class PeerSubscriber(ABC):
548559
_msg_queue: 'asyncio.Queue[PEER_MSG_TYPE]' = None
549560

561+
@property
562+
@abstractmethod
563+
def subscription_msg_types(self) -> Set[Type[protocol.Command]]:
564+
"""
565+
The `p2p.protocol.Command` types that this class subscribes to. Any
566+
command which is not in this set will not be passed to this subscriber.
567+
568+
The base command class `p2p.protocol.Command` can be used to enable
569+
**all** command types.
570+
"""
571+
pass
572+
573+
@functools.lru_cache(maxsize=64)
574+
def is_subscription_command(self, cmd_type: Type[protocol.Command]) -> bool:
575+
return bool(self.subscription_msg_types.intersection(
576+
{cmd_type, protocol.Command}
577+
))
578+
550579
@property
551580
@abstractmethod
552581
def msg_queue_maxsize(self) -> int:
553-
raise NotImplementedError("Must be implemented by subclasses")
582+
pass
554583

555584
def register_peer(self, peer: BasePeer) -> None:
556585
"""
@@ -577,16 +606,30 @@ def msg_queue(self) -> 'asyncio.Queue[PEER_MSG_TYPE]':
577606
def queue_size(self) -> int:
578607
return self.msg_queue.qsize()
579608

580-
def add_msg(self, msg: 'PEER_MSG_TYPE') -> None:
609+
def add_msg(self, msg: 'PEER_MSG_TYPE') -> bool:
581610
peer, cmd, _ = msg
611+
612+
if not self.is_subscription_command(type(cmd)):
613+
if hasattr(self, 'logger'):
614+
self.logger.trace( # type: ignore
615+
"Discarding %s msg from %s; not subscribed to msg type; "
616+
"subscriptions: %s",
617+
cmd, peer, self.subscription_msg_types,
618+
)
619+
return False
620+
582621
try:
583-
self.logger.trace( # type: ignore
584-
"Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size)
622+
if hasattr(self, 'logger'):
623+
self.logger.trace( # type: ignore
624+
"Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size)
585625
self.msg_queue.put_nowait(msg)
626+
return True
586627
except asyncio.queues.QueueFull:
587-
self.logger.warn( # type: ignore
588-
"%s msg queue is full; discarding %s msg from %s",
589-
self.__class__.__name__, cmd, peer)
628+
if hasattr(self, 'logger'):
629+
self.logger.warn( # type: ignore
630+
"%s msg queue is full; discarding %s msg from %s",
631+
self.__class__.__name__, cmd, peer)
632+
return False
590633

591634
@contextlib.contextmanager
592635
def subscribe(self, peer_pool: 'PeerPool') -> Iterator[None]:

tests/p2p/test_peer_subscriber.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from p2p.peer import PeerSubscriber
6+
from p2p.protocol import Command
7+
8+
from trinity.protocol.eth.peer import ETHPeer
9+
from trinity.protocol.eth.commands import GetBlockHeaders
10+
11+
from tests.trinity.core.peer_helpers import (
12+
get_directly_linked_peers,
13+
)
14+
15+
16+
class HeadersSubscriber(PeerSubscriber):
17+
msg_queue_maxsize = 10
18+
subscription_msg_types = {GetBlockHeaders}
19+
20+
21+
class AllSubscriber(PeerSubscriber):
22+
msg_queue_maxsize = 10
23+
subscription_msg_types = {Command}
24+
25+
26+
@pytest.mark.asyncio
27+
async def test_peer_subscriber_filters_messages(request, event_loop):
28+
peer, remote = await get_directly_linked_peers(
29+
request,
30+
event_loop,
31+
peer1_class=ETHPeer,
32+
peer2_class=ETHPeer,
33+
)
34+
35+
header_subscriber = HeadersSubscriber()
36+
all_subscriber = AllSubscriber()
37+
38+
peer.add_subscriber(header_subscriber)
39+
peer.add_subscriber(all_subscriber)
40+
41+
remote.sub_proto.send_get_node_data([b'\x00' * 32])
42+
remote.sub_proto.send_get_block_headers(0, 1, 0, False)
43+
remote.sub_proto.send_get_node_data([b'\x00' * 32])
44+
remote.sub_proto.send_get_block_headers(1, 1, 0, False)
45+
remote.sub_proto.send_get_node_data([b'\x00' * 32])
46+
47+
# yeild to let remote and peer transmit.
48+
await asyncio.sleep(0.01)
49+
50+
assert header_subscriber.queue_size == 2
51+
assert all_subscriber.queue_size == 5

tests/trinity/core/peer_helpers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import os
3-
from typing import List
3+
from typing import (
4+
List,
5+
)
46

57
from eth_hash.auto import keccak
68

@@ -16,6 +18,7 @@
1618
from p2p import kademlia
1719
from p2p.auth import decode_authentication
1820
from p2p.peer import BasePeer, PeerPool, PeerSubscriber
21+
from p2p.protocol import Command
1922

2023

2124
from trinity.protocol.les.peer import LESPeer
@@ -174,6 +177,8 @@ async def _run(self) -> None:
174177
class SamplePeerSubscriber(PeerSubscriber):
175178
logger = TraceLogger("")
176179

180+
subscription_msg_types = {Command}
181+
177182
@property
178183
def msg_queue_maxsize(self) -> int:
179184
return 100

trinity/plugins/builtin/tx_pool/pool.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
cast,
33
Callable,
44
Iterable,
5-
List
5+
List,
6+
Set,
7+
Type,
68
)
79
import uuid
810

@@ -21,6 +23,7 @@
2123
PeerPool,
2224
PeerSubscriber,
2325
)
26+
from p2p.protocol import Command
2427
from p2p.service import (
2528
BaseService
2629
)
@@ -59,12 +62,12 @@ def __init__(self,
5962
self._bloom = BloomFilter(max_elements=1000000)
6063
self._bloom_salt = str(uuid.uuid4())
6164

62-
@property
63-
def msg_queue_maxsize(self) -> int:
64-
# This is a rather arbitrary value, but when the sync is operating normally we never see
65-
# the msg queue grow past a few hundred items, so this should be a reasonable limit for
66-
# now.
67-
return 2000
65+
subscription_msg_types: Set[Type[Command]] = {Transactions}
66+
67+
# This is a rather arbitrary value, but when the sync is operating normally we never see
68+
# the msg queue grow past a few hundred items, so this should be a reasonable limit for
69+
# now.
70+
msg_queue_maxsize: int = 2000
6871

6972
async def _run(self) -> None:
7073
self.logger.info("Running Tx Pool")
@@ -74,8 +77,8 @@ async def _run(self) -> None:
7477
peer, cmd, msg = await self.wait(
7578
self.msg_queue.get(), token=self.cancel_token)
7679
peer = cast(ETHPeer, peer)
77-
msg = cast(List[BaseTransactionFields], msg)
7880
if isinstance(cmd, Transactions):
81+
msg = cast(List[BaseTransactionFields], msg)
7982
await self._handle_tx(peer, msg)
8083

8184
async def _handle_tx(self, peer: ETHPeer, txs: List[BaseTransactionFields]) -> None:

trinity/sync/full/chain.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
Dict,
77
List,
88
NamedTuple,
9+
Set,
910
Tuple,
11+
Type,
1012
Union,
1113
cast,
1214
)
@@ -32,9 +34,15 @@
3234
from p2p.exceptions import NoEligiblePeers
3335
from p2p.p2p_proto import DisconnectReason
3436
from p2p.peer import PeerPool
37+
from p2p.protocol import Command
3538

3639
from trinity.db.chain import AsyncChainDB
40+
from trinity.protocol.eth import commands
41+
from trinity.protocol.eth import (
42+
constants as eth_constants,
43+
)
3744
from trinity.protocol.eth.peer import ETHPeer
45+
from trinity.protocol.eth.requests import HeaderRequest
3846
from trinity.protocol.les.peer import LESPeer
3947
from trinity.rlp.block_body import BlockBody
4048
from trinity.sync.base_chain_syncer import BaseHeaderChainSyncer
@@ -66,6 +74,19 @@ def __init__(self,
6674
self._downloaded_receipts: asyncio.Queue[Tuple[ETHPeer, List[DownloadedBlockPart]]] = asyncio.Queue() # noqa: E501
6775
self._downloaded_bodies: asyncio.Queue[Tuple[ETHPeer, List[DownloadedBlockPart]]] = asyncio.Queue() # noqa: E501
6876

77+
subscription_msg_types: Set[Type[Command]] = {
78+
commands.BlockBodies,
79+
commands.Receipts,
80+
commands.NewBlock,
81+
commands.GetBlockHeaders,
82+
commands.BlockHeaders,
83+
commands.GetBlockBodies,
84+
commands.GetReceipts,
85+
commands.GetNodeData,
86+
commands.Transactions,
87+
commands.NodeData,
88+
}
89+
6990
async def _calculate_td(self, headers: Tuple[BlockHeader, ...]) -> int:
7091
"""Return the score (total difficulty) of the last header in the given list.
7192
@@ -191,7 +212,6 @@ def _request_block_parts(
191212
target_td: int,
192213
headers: List[BlockHeader],
193214
request_func: Callable[[ETHPeer, List[BlockHeader]], None]) -> int:
194-
from trinity.protocol.eth.peer import ETHPeer # noqa: F811
195215
peers = self.peer_pool.get_peers(target_td)
196216
if not peers:
197217
raise NoEligiblePeers()
@@ -235,12 +255,6 @@ def request_receipts(self, target_td: int, headers: List[BlockHeader]) -> int:
235255

236256
async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
237257
msg: protocol._DecodedMsgType) -> None:
238-
from trinity.protocol.eth.peer import ETHPeer # noqa: F811
239-
from trinity.protocol.eth import commands
240-
from trinity.protocol.eth import (
241-
constants as eth_constants,
242-
)
243-
244258
peer = cast(ETHPeer, peer)
245259

246260
if isinstance(cmd, commands.BlockBodies):
@@ -318,8 +332,6 @@ async def _handle_get_block_headers(
318332
self,
319333
peer: ETHPeer,
320334
query: Dict[str, Any]) -> None:
321-
from trinity.protocol.eth.requests import HeaderRequest # noqa: F811
322-
323335
self.logger.debug("Peer %s made header request: %s", peer, query)
324336
request = HeaderRequest(
325337
query['block_number_or_hash'],

trinity/sync/full/state.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
List,
1212
Set,
1313
Tuple,
14+
Type,
1415
Union,
1516
)
1617

@@ -88,12 +89,21 @@ def __init__(self,
8889
self._peer_missing_nodes: Dict[ETHPeer, Set[Hash32]] = collections.defaultdict(set)
8990
self._executor = get_asyncio_executor()
9091

91-
@property
92-
def msg_queue_maxsize(self) -> int:
93-
# This is a rather arbitrary value, but when the sync is operating normally we never see
94-
# the msg queue grow past a few hundred items, so this should be a reasonable limit for
95-
# now.
96-
return 2000
92+
# Throughout the whole state sync our chain head is fixed, so it makes sense to ignore
93+
# messages related to new blocks/transactions, but we must handle requests for data from
94+
# other peers or else they will disconnect from us.
95+
subscription_msg_types: Set[Type[Command]] = {
96+
commands.NodeData,
97+
commands.GetBlockHeaders,
98+
commands.GetBlockBodies,
99+
commands.GetReceipts,
100+
commands.GetNodeData,
101+
}
102+
103+
# This is a rather arbitrary value, but when the sync is operating normally we never see
104+
# the msg queue grow past a few hundred items, so this should be a reasonable limit for
105+
# now.
106+
msg_queue_maxsize: int = 2000
97107

98108
def deregister_peer(self, peer: BasePeer) -> None:
99109
# Use .pop() with a default value as it's possible we never requested anything to this
@@ -154,13 +164,7 @@ async def _process_nodes(self, nodes: Iterable[Tuple[Hash32, bytes]]) -> None:
154164

155165
async def _handle_msg(
156166
self, peer: ETHPeer, cmd: Command, msg: _DecodedMsgType) -> None:
157-
# Throughout the whole state sync our chain head is fixed, so it makes sense to ignore
158-
# messages related to new blocks/transactions, but we must handle requests for data from
159-
# other peers or else they will disconnect from us.
160-
ignored_commands = (commands.Transactions, commands.NewBlock, commands.NewBlockHashes)
161-
if isinstance(cmd, ignored_commands):
162-
pass
163-
elif isinstance(cmd, commands.NodeData):
167+
if isinstance(cmd, commands.NodeData):
164168
msg = cast(List[bytes], msg)
165169
if peer not in self.request_tracker.active_requests:
166170
# This is probably a batch that we retried after a timeout and ended up receiving

0 commit comments

Comments
 (0)