Skip to content

Commit 9b80f45

Browse files
authored
Merge pull request #1184 from carver/request-manager-refactor
Reducing responsibility of BaseRequestManager
2 parents 2eaf026 + aa500c2 commit 9b80f45

37 files changed

+1152
-858
lines changed

p2p/peer.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Dict,
2121
Iterator,
2222
List,
23+
NamedTuple,
2324
Set,
2425
TYPE_CHECKING,
2526
Tuple,
@@ -209,7 +210,7 @@ async def send_sub_proto_handshake(self) -> None:
209210

210211
@abstractmethod
211212
async def process_sub_proto_handshake(
212-
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
213+
self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
213214
raise NotImplementedError("Must be implemented by subclasses")
214215

215216
@contextlib.contextmanager
@@ -365,7 +366,7 @@ async def _run(self) -> None:
365366
self.logger.debug("%s disconnected: %s", self, e)
366367
return
367368

368-
async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
369+
async def read_msg(self) -> Tuple[protocol.Command, protocol.PayloadType]:
369370
header_data = await self.read(HEADER_LEN + MAC_LEN)
370371
header = self.decrypt_header(header_data)
371372
frame_size = self.get_frame_size(header)
@@ -392,7 +393,7 @@ async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
392393
self.received_msgs[cmd] += 1
393394
return cmd, decoded_msg
394395

395-
def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
396+
def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
396397
"""Handle the base protocol (P2P) messages."""
397398
if isinstance(cmd, Disconnect):
398399
msg = cast(Dict[str, Any], msg)
@@ -406,12 +407,12 @@ def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -
406407
else:
407408
raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))
408409

409-
def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
410+
def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
410411
cmd_type = type(cmd)
411412

412413
if self._subscribers:
413414
was_added = tuple(
414-
subscriber.add_msg((self, cmd, msg))
415+
subscriber.add_msg(PeerMessage(self, cmd, msg))
415416
for subscriber
416417
in self._subscribers
417418
)
@@ -424,14 +425,14 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
424425
else:
425426
self.logger.warn("Peer %s has no subscribers, discarding %s msg", self, cmd)
426427

427-
def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
428+
def process_msg(self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
428429
if cmd.is_base_protocol:
429430
self.handle_p2p_msg(cmd, msg)
430431
else:
431432
self.handle_sub_proto_msg(cmd, msg)
432433

433434
async def process_p2p_handshake(
434-
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
435+
self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
435436
msg = cast(Dict[str, Any], msg)
436437
if not isinstance(cmd, Hello):
437438
await self.disconnect(DisconnectReason.bad_protocol)
@@ -478,7 +479,7 @@ def decrypt_header(self, data: bytes) -> bytes:
478479
self.ingress_mac.update(sxor(aes, header_ciphertext))
479480
expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
480481
if not bytes_eq(expected_header_mac, header_mac):
481-
raise DecryptionError('Invalid header mac: expected %s, got %s'.format(
482+
raise DecryptionError('Invalid header mac: expected {}, got {}'.format(
482483
expected_header_mac, header_mac))
483484
return self.aes_dec.update(header_ciphertext)
484485

@@ -563,8 +564,14 @@ def __hash__(self) -> int:
563564
return hash(self.remote)
564565

565566

567+
class PeerMessage(NamedTuple):
568+
peer: BasePeer
569+
command: protocol.Command
570+
payload: protocol.PayloadType
571+
572+
566573
class PeerSubscriber(ABC):
567-
_msg_queue: 'asyncio.Queue[PEER_MSG_TYPE]' = None
574+
_msg_queue: 'asyncio.Queue[PeerMessage]' = None
568575

569576
@property
570577
@abstractmethod
@@ -609,7 +616,7 @@ def deregister_peer(self, peer: BasePeer) -> None:
609616
pass
610617

611618
@property
612-
def msg_queue(self) -> 'asyncio.Queue[PEER_MSG_TYPE]':
619+
def msg_queue(self) -> 'asyncio.Queue[PeerMessage]':
613620
if self._msg_queue is None:
614621
self._msg_queue = asyncio.Queue(maxsize=self.msg_queue_maxsize)
615622
return self._msg_queue
@@ -618,26 +625,29 @@ def msg_queue(self) -> 'asyncio.Queue[PEER_MSG_TYPE]':
618625
def queue_size(self) -> int:
619626
return self.msg_queue.qsize()
620627

621-
def add_msg(self, msg: 'PEER_MSG_TYPE') -> bool:
628+
def add_msg(self, msg: PeerMessage) -> bool:
622629
peer, cmd, _ = msg
623630

624631
if not self.is_subscription_command(type(cmd)):
625-
self.logger.trace( # type: ignore
626-
"Discarding %s msg from %s; not subscribed to msg type; "
627-
"subscriptions: %s",
628-
cmd, peer, self.subscription_msg_types,
629-
)
632+
if hasattr(self, 'logger'):
633+
self.logger.trace( # type: ignore
634+
"Discarding %s msg from %s; not subscribed to msg type; "
635+
"subscriptions: %s",
636+
cmd, peer, self.subscription_msg_types,
637+
)
630638
return False
631639

632640
try:
633-
self.logger.trace( # type: ignore
634-
"Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size)
641+
if hasattr(self, 'logger'):
642+
self.logger.trace( # type: ignore
643+
"Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size)
635644
self.msg_queue.put_nowait(msg)
636645
return True
637646
except asyncio.queues.QueueFull:
638-
self.logger.warn( # type: ignore
639-
"%s msg queue is full; discarding %s msg from %s",
640-
self.__class__.__name__, cmd, peer)
647+
if hasattr(self, 'logger'):
648+
self.logger.warn( # type: ignore
649+
"%s msg queue is full; discarding %s msg from %s",
650+
self.__class__.__name__, cmd, peer)
641651
return False
642652

643653
@contextlib.contextmanager
@@ -663,7 +673,7 @@ class MsgBuffer(PeerSubscriber):
663673
subscription_msg_types = {protocol.Command}
664674

665675
@to_tuple
666-
def get_messages(self) -> Iterator['PEER_MSG_TYPE']:
676+
def get_messages(self) -> Iterator[PeerMessage]:
667677
while not self.msg_queue.empty():
668678
yield self.msg_queue.get_nowait()
669679

@@ -740,7 +750,7 @@ async def start_peer(self, peer: BasePeer) -> None:
740750

741751
def _add_peer(self,
742752
peer: BasePeer,
743-
msgs: Tuple[Tuple[protocol.Command, protocol._DecodedMsgType], ...]) -> None:
753+
msgs: Tuple[Tuple[protocol.Command, protocol.PayloadType], ...]) -> None:
744754
"""Add the given peer to the pool.
745755
746756
Appart from adding it to our list of connected nodes and adding each of our subscriber's
@@ -753,7 +763,7 @@ def _add_peer(self,
753763
subscriber.register_peer(peer)
754764
peer.add_subscriber(subscriber)
755765
for cmd, msg in msgs:
756-
subscriber.add_msg((peer, cmd, msg))
766+
subscriber.add_msg(PeerMessage(peer, cmd, msg))
757767

758768
async def _run(self) -> None:
759769
# FIXME: PeerPool should probably no longer be a BaseService, but for now we're keeping it
@@ -1006,9 +1016,6 @@ def __init__(self,
10061016
self.genesis_hash = genesis_hash
10071017

10081018

1009-
PEER_MSG_TYPE = Tuple[BasePeer, protocol.Command, protocol._DecodedMsgType]
1010-
1011-
10121019
def _test() -> None:
10131020
"""
10141021
Create a Peer instance connected to a local geth instance and log messages exchanged with it.
@@ -1066,8 +1073,8 @@ async def request_stuff() -> None:
10661073
hashes = tuple(header.hash for header in headers)
10671074
if peer_class == ETHPeer:
10681075
peer = cast(ETHPeer, peer)
1069-
peer.sub_proto._send_get_block_bodies(hashes)
1070-
peer.sub_proto._send_get_receipts(hashes)
1076+
peer.sub_proto.send_get_block_bodies(hashes)
1077+
peer.sub_proto.send_get_receipts(hashes)
10711078
else:
10721079
peer = cast(LESPeer, peer)
10731080
request_id = 1

p2p/protocol.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
from abc import ABC, abstractmethod
12
import logging
23
import struct
34
from typing import (
45
Any,
56
Dict,
7+
Generic,
68
List,
79
Tuple,
810
Type,
11+
TypeVar,
912
TYPE_CHECKING,
1013
Union,
1114
)
@@ -20,19 +23,23 @@
2023
)
2124
from p2p.utils import get_devp2p_cmd_id
2225

23-
2426
# Workaround for import cycles caused by type annotations:
2527
# http://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
2628
if TYPE_CHECKING:
2729
from p2p.peer import ChainInfo, BasePeer # noqa: F401
2830

29-
30-
_DecodedMsgType = Union[
31+
PayloadType = Union[
3132
Dict[str, Any],
3233
List[rlp.Serializable],
3334
Tuple[rlp.Serializable, ...],
3435
]
3536

37+
# A payload to be delivered with a request
38+
TRequestPayload = TypeVar('TRequestPayload', bound=PayloadType, covariant=True)
39+
40+
# for backwards compatibility for internal references in p2p:
41+
_DecodedMsgType = PayloadType
42+
3643

3744
class Command:
3845
_cmd_id: int = None
@@ -60,7 +67,7 @@ def is_base_protocol(self) -> bool:
6067
def __str__(self) -> str:
6168
return "{} (cmd_id={})".format(self.__class__.__name__, self.cmd_id)
6269

63-
def encode_payload(self, data: Union[_DecodedMsgType, sedes.CountableList]) -> bytes:
70+
def encode_payload(self, data: Union[PayloadType, sedes.CountableList]) -> bytes:
6471
if isinstance(data, dict): # convert dict to ordered list
6572
if not isinstance(self.structure, list):
6673
raise ValueError("Command.structure must be a list when data is a dict")
@@ -76,7 +83,7 @@ def encode_payload(self, data: Union[_DecodedMsgType, sedes.CountableList]) -> b
7683
encoder = sedes.List([type_ for _, type_ in self.structure])
7784
return rlp.encode(data, sedes=encoder)
7885

79-
def decode_payload(self, rlp_data: bytes) -> _DecodedMsgType:
86+
def decode_payload(self, rlp_data: bytes) -> PayloadType:
8087
if isinstance(self.structure, sedes.CountableList):
8188
decoder = self.structure
8289
else:
@@ -97,13 +104,13 @@ def decode_payload(self, rlp_data: bytes) -> _DecodedMsgType:
97104
in zip(self.structure, data)
98105
}
99106

100-
def decode(self, data: bytes) -> _DecodedMsgType:
107+
def decode(self, data: bytes) -> PayloadType:
101108
packet_type = get_devp2p_cmd_id(data)
102109
if packet_type != self.cmd_id:
103110
raise MalformedMessage("Wrong packet type: {}".format(packet_type))
104111
return self.decode_payload(data[1:])
105112

106-
def encode(self, data: _DecodedMsgType) -> Tuple[bytes, bytes]:
113+
def encode(self, data: PayloadType) -> Tuple[bytes, bytes]:
107114
payload = self.encode_payload(data)
108115
enc_cmd_id = rlp.encode(self.cmd_id, sedes=rlp.sedes.big_endian_int)
109116
frame_size = len(enc_cmd_id) + len(payload)
@@ -123,6 +130,24 @@ def encode(self, data: _DecodedMsgType) -> Tuple[bytes, bytes]:
123130
return header, body
124131

125132

133+
class BaseRequest(ABC, Generic[TRequestPayload]):
134+
"""
135+
Must define command_payload during init. This is the data that will
136+
be sent to the peer with the request command.
137+
"""
138+
command_payload: TRequestPayload
139+
140+
@property
141+
@abstractmethod
142+
def cmd_type(self) -> Type[Command]:
143+
raise NotImplementedError
144+
145+
@property
146+
@abstractmethod
147+
def response_type(self) -> Type[Command]:
148+
raise NotImplementedError
149+
150+
126151
class Protocol:
127152
logger = logging.getLogger("p2p.protocol.Protocol")
128153
name: str = None
@@ -135,11 +160,20 @@ def __init__(self, peer: 'BasePeer', cmd_id_offset: int) -> None:
135160
self.peer = peer
136161
self.cmd_id_offset = cmd_id_offset
137162
self.commands = [cmd_class(cmd_id_offset) for cmd_class in self._commands]
163+
self.cmd_by_type = {cmd_class: cmd_class(cmd_id_offset) for cmd_class in self._commands}
138164
self.cmd_by_id = dict((cmd.cmd_id, cmd) for cmd in self.commands)
139165

140166
def send(self, header: bytes, body: bytes) -> None:
141167
self.peer.send(header, body)
142168

169+
def send_request(self, request: BaseRequest[PayloadType]) -> None:
170+
command = self.cmd_by_type[request.cmd_type]
171+
header, body = command.encode(request.command_payload)
172+
self.send(header, body)
173+
174+
def supports_command(self, cmd_type: Type[Command]) -> bool:
175+
return cmd_type in self.cmd_by_type
176+
143177
def __repr__(self) -> str:
144178
return "(%s, %d)" % (self.name, self.version)
145179

tests/p2p/test_peer_collect_sub_proto_msgs.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55

66
from trinity.protocol.eth.peer import ETHPeer
77
from trinity.protocol.eth.commands import GetBlockHeaders, GetNodeData
8-
from trinity.protocol.eth.requests import (
9-
HeaderRequest,
10-
NodeDataRequest,
11-
)
128

139
from tests.trinity.core.peer_helpers import (
1410
get_directly_linked_peers,
@@ -30,11 +26,11 @@ async def test_peer_subscriber_filters_messages(request, event_loop):
3026

3127
with peer.collect_sub_proto_messages() as collector:
3228
assert collector in peer._subscribers
33-
remote.sub_proto.send_get_node_data(NodeDataRequest([b'\x00' * 32]))
34-
remote.sub_proto.send_get_block_headers(HeaderRequest(0, 1, 0, False))
35-
remote.sub_proto.send_get_node_data(NodeDataRequest([b'\x00' * 32]))
36-
remote.sub_proto.send_get_block_headers(HeaderRequest(1, 1, 0, False))
37-
remote.sub_proto.send_get_node_data(NodeDataRequest([b'\x00' * 32]))
29+
remote.sub_proto.send_get_node_data([b'\x00' * 32])
30+
remote.sub_proto.send_get_block_headers(0, 1, 0, False)
31+
remote.sub_proto.send_get_node_data([b'\x00' * 32])
32+
remote.sub_proto.send_get_block_headers(1, 1, 0, False)
33+
remote.sub_proto.send_get_node_data([b'\x00' * 32])
3834
await asyncio.sleep(0.01)
3935

4036
assert collector not in peer._subscribers
@@ -51,7 +47,7 @@ async def test_peer_subscriber_filters_messages(request, event_loop):
5147
assert isinstance(all_messages[4][1], GetNodeData)
5248

5349
# make sure it isn't still collecting
54-
remote.sub_proto.send_get_block_headers(HeaderRequest(1, 1, 0, False))
50+
remote.sub_proto.send_get_block_headers(1, 1, 0, False)
5551

5652
await asyncio.sleep(0.01)
5753

tests/p2p/test_peer_subscriber.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@
88

99
from trinity.protocol.eth.peer import ETHPeer
1010
from trinity.protocol.eth.commands import GetBlockHeaders
11-
from trinity.protocol.eth.requests import (
12-
HeaderRequest,
13-
NodeDataRequest,
14-
)
1511

1612
from tests.trinity.core.peer_helpers import (
1713
get_directly_linked_peers,
@@ -48,11 +44,11 @@ async def test_peer_subscriber_filters_messages(request, event_loop):
4844
peer.add_subscriber(header_subscriber)
4945
peer.add_subscriber(all_subscriber)
5046

51-
remote.sub_proto.send_get_node_data(NodeDataRequest([b'\x00' * 32]))
52-
remote.sub_proto.send_get_block_headers(HeaderRequest(0, 1, 0, False))
53-
remote.sub_proto.send_get_node_data(NodeDataRequest([b'\x00' * 32]))
54-
remote.sub_proto.send_get_block_headers(HeaderRequest(1, 1, 0, False))
55-
remote.sub_proto.send_get_node_data(NodeDataRequest([b'\x00' * 32]))
47+
remote.sub_proto.send_get_node_data(tuple([b'\x00' * 32]))
48+
remote.sub_proto.send_get_block_headers(0, 1, 0, False)
49+
remote.sub_proto.send_get_node_data(tuple([b'\x00' * 32]))
50+
remote.sub_proto.send_get_block_headers(1, 1, 0, False)
51+
remote.sub_proto.send_get_node_data(tuple([b'\x00' * 32]))
5652

5753
# yeild to let remote and peer transmit.
5854
await asyncio.sleep(0.01)

0 commit comments

Comments
 (0)