
Commit 47a2155

Merge pull request #1138 from pipermerriam/piper/establish-round-trip-request-response-pattern
Establish pattern for generic round trip request response handling
2 parents: bf1d3c2 + 5366937

33 files changed (+583, -254 lines)

p2p/peer.py

Lines changed: 60 additions & 63 deletions
@@ -7,7 +7,6 @@
 import operator
 import random
 import struct
-import time
 from abc import (
     ABC,
     abstractmethod
@@ -39,6 +38,7 @@
 
 from eth_utils import (
     decode_hex,
+    to_tuple,
 )
 
 from eth_typing import BlockNumber, Hash32
@@ -155,14 +155,6 @@ class BasePeer(BaseService):
     head_td: int = None
     head_hash: Hash32 = None
 
-    # TODO: Instead of a fixed timeout, we should instead monitor response
-    # times for the peer and adjust our timeout accordingly
-    _response_timeout = 60
-    pending_requests: Dict[
-        Type[protocol.Command],
-        Tuple['BaseRequest', 'asyncio.Future[protocol._DecodedMsgType]'],
-    ]
-
     def __init__(self,
                  remote: Node,
                  privkey: datatypes.PrivateKey,
@@ -189,8 +181,6 @@ def __init__(self,
         self.start_time = datetime.datetime.now()
         self.received_msgs: Dict[protocol.Command, int] = collections.defaultdict(int)
 
-        self.pending_requests = {}
-
         self.egress_mac = egress_mac
         self.ingress_mac = ingress_mac
         # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
@@ -210,6 +200,18 @@ async def process_sub_proto_handshake(
             self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
         raise NotImplementedError("Must be implemented by subclasses")
 
+    @contextlib.contextmanager
+    def collect_sub_proto_messages(self) -> Iterator['MsgBuffer']:
+        """
+        Can be used to gather up all messages that are sent to the peer.
+        """
+        if not self.is_running:
+            raise RuntimeError("Cannot collect messages if peer is not running")
+        msg_buffer = MsgBuffer()
+
+        with msg_buffer.subscribe_peer(self):
+            yield msg_buffer
+
     @property
     def received_msgs_count(self) -> int:
         return sum(self.received_msgs.values())
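
The new context manager gives callers a scoped way to capture a peer's traffic during a setup exchange. A minimal usage sketch (not from the commit: `do_check` is a hypothetical coroutine and `peer` an already-running BasePeer):

    async def check_peer(peer):
        with peer.collect_sub_proto_messages() as buffer:
            # Messages arriving while this block runs are captured in
            # `buffer` instead of being dropped for lack of subscribers.
            await do_check(peer)
        # Entries are (peer, cmd, msg) tuples, drained via get_messages().
        for _, cmd, msg in buffer.get_messages():
            print(cmd, msg)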
@@ -400,22 +402,6 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
         else:
             self.logger.warn("Peer %s has no subscribers, discarding %s msg", self, cmd)
 
-        if cmd_type in self.pending_requests:
-            request, future = self.pending_requests[cmd_type]
-            try:
-                request.validate_response(msg)
-            except ValidationError as err:
-                self.logger.debug(
-                    "Response validation failure for pending %s request from peer %s: %s",
-                    cmd_type.__name__,
-                    self,
-                    err,
-                )
-                pass
-            else:
-                future.set_result(msg)
-                self.pending_requests.pop(cmd_type)
-
     def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
         if cmd.is_base_protocol:
             self.handle_p2p_msg(cmd, msg)
@@ -640,6 +626,25 @@ def subscribe(self, peer_pool: 'PeerPool') -> Iterator[None]:
         finally:
             peer_pool.unsubscribe(self)
 
+    @contextlib.contextmanager
+    def subscribe_peer(self, peer: BasePeer) -> Iterator[None]:
+        peer.add_subscriber(self)
+        try:
+            yield
+        finally:
+            peer.remove_subscriber(self)
+
+
+class MsgBuffer(PeerSubscriber):
+    logger = logging.getLogger('p2p.peer.MsgBuffer')
+    msg_queue_maxsize = 500
+    subscription_msg_types = {protocol.Command}
+
+    @to_tuple
+    def get_messages(self) -> Iterator['PEER_MSG_TYPE']:
+        while not self.msg_queue.empty():
+            yield self.msg_queue.get_nowait()
+
 
 class PeerPool(BaseService, AsyncIterable[BasePeer]):
     """
@@ -695,28 +700,33 @@ def unsubscribe(self, subscriber: PeerSubscriber) -> None:
             peer.remove_subscriber(subscriber)
 
     async def start_peer(self, peer: BasePeer) -> None:
+        asyncio.ensure_future(peer.run())
+        await self.wait(peer.events.started.wait(), timeout=1)
         try:
             # Although connect() may seem like a more appropriate place to perform the DAO fork
             # check, we do it here because we want to perform it for incoming peer connections as
             # well.
-            msgs = await self.ensure_same_side_on_dao_fork(peer)
+            with peer.collect_sub_proto_messages() as buffer:
+                await self.ensure_same_side_on_dao_fork(peer)
         except DAOForkCheckFailure as err:
             self.logger.debug("DAO fork check with %s failed: %s", peer, err)
             await peer.disconnect(DisconnectReason.useless_peer)
             return
-        asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
-        self._add_peer(peer, msgs)
+        else:
+            msgs = tuple((cmd, msg) for _, cmd, msg in buffer.get_messages())
+            self._add_peer(peer, msgs)
 
     def _add_peer(self,
                   peer: BasePeer,
-                  msgs: List[Tuple[protocol.Command, protocol._DecodedMsgType]]) -> None:
+                  msgs: Tuple[Tuple[protocol.Command, protocol._DecodedMsgType], ...]) -> None:
         """Add the given peer to the pool.
 
         Appart from adding it to our list of connected nodes and adding each of our subscriber's
         to the peer, we also add the given messages to our subscriber's queues.
         """
         self.logger.info('Adding %s to pool', peer)
         self.connected_nodes[peer.remote] = peer
+        peer.add_finished_callback(self._peer_finished)
         for subscriber in self._subscribers:
             subscriber.register_peer(peer)
             peer.add_subscriber(subscriber)
@@ -787,15 +797,13 @@ async def connect_to_nodes(self, nodes: Iterator[Node]) -> None:
             await self.start_peer(peer)
 
     async def ensure_same_side_on_dao_fork(
-            self, peer: BasePeer) -> List[Tuple[protocol.Command, protocol._DecodedMsgType]]:
+            self, peer: BasePeer) -> None:
         """Ensure we're on the same side of the DAO fork as the given peer.
 
         In order to do that we have to request the DAO fork block and its parent, but while we
         wait for that we may receive other messages from the peer, which are returned so that they
         can be re-added to our subscribers' queues when the peer is finally added to the pool.
         """
-        from trinity.protocol.base_block_headers import BaseBlockHeaders
-        msgs = []
         for start_block, vm_class in self.vm_configuration:
             if not issubclass(vm_class, HomesteadVM):
                 continue
@@ -806,46 +814,33 @@ async def ensure_same_side_on_dao_fork(
                 break
 
             start_block = vm_class.dao_fork_block_number - 1
-            # TODO: This can be either an `ETHPeer` or an `LESPeer`. Will be
-            # fixed once full awaitable request API is completed.
-            request = peer.request_block_headers(  # type: ignore
-                start_block,
-                max_headers=2,
-                reverse=False,
-            )
-            start = time.time()
+
             try:
-                while True:
-                    elapsed = int(time.time() - start)
-                    remaining_timeout = max(0, CHAIN_SPLIT_CHECK_TIMEOUT - elapsed)
-                    cmd, msg = await self.wait(
-                        peer.read_msg(), timeout=remaining_timeout)
-                    if isinstance(cmd, BaseBlockHeaders):
-                        headers = cmd.extract_headers(msg)
-                        break
-                    else:
-                        msgs.append((cmd, msg))
-                        continue
+                headers = await peer.requests.get_block_headers(  # type: ignore
+                    start_block,
+                    max_headers=2,
+                    reverse=False,
+                    timeout=CHAIN_SPLIT_CHECK_TIMEOUT,
+                )
+
             except (TimeoutError, PeerConnectionLost) as err:
                 raise DAOForkCheckFailure(
-                    "Timed out waiting for DAO fork header from {}: {}".format(peer, err))
+                    "Timed out waiting for DAO fork header from {}: {}".format(peer, err)
+                ) from err
             except MalformedMessage as err:
                 raise DAOForkCheckFailure(
                     "Malformed message while doing DAO fork check with {0}: {1}".format(
                         peer, err,
                     )
                 ) from err
-
-            try:
-                request.validate_headers(headers)
             except ValidationError as err:
                 raise DAOForkCheckFailure(
                     "Invalid header response during DAO fork check: {}".format(err)
-                )
+                ) from err
 
             if len(headers) != 2:
                 raise DAOForkCheckFailure(
-                    "Peer failed to return all requested headers for DAO fork check"
+                    "Peer %s failed to return DAO fork check headers".format(peer)
                 )
             else:
                 parent, header = headers
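
This hunk carries the PR's central idea: the hand-rolled read loop is replaced by a single awaitable `peer.requests.get_block_headers(...)` call. The general shape of such a round-trip manager, sketched here with illustrative names rather than the actual trinity classes, is a future that the message-dispatch path resolves when a matching response arrives:

    import asyncio
    from typing import Any, Callable

    class RoundTripSketch:
        """Illustrative only: pair one outbound request with one awaited response."""

        def __init__(self) -> None:
            self._pending: 'asyncio.Future[Any]' = None

        async def request(self, send: Callable[[], None], timeout: float) -> Any:
            # Arm the future before sending so a fast response cannot be missed.
            self._pending = asyncio.get_event_loop().create_future()
            send()
            try:
                return await asyncio.wait_for(self._pending, timeout=timeout)
            finally:
                self._pending = None

        def deliver(self, response: Any) -> None:
            # Called from the message handler when the matching command arrives.
            if self._pending is not None and not self._pending.done():
                self._pending.set_result(response)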
@@ -855,8 +850,6 @@ async def ensure_same_side_on_dao_fork(
             except EthValidationError as err:
                 raise DAOForkCheckFailure("Peer failed DAO fork check validation: {}".format(err))
 
-        return msgs
-
     def _peer_finished(self, peer: BaseService) -> None:
         """Remove the given peer from our list of connected nodes.
         This is passed as a callback to be called when a peer finishes.
@@ -1008,7 +1001,9 @@ def _test() -> None:
     from eth.chains.ropsten import RopstenChain, ROPSTEN_GENESIS_HEADER, ROPSTEN_VM_CONFIGURATION
     from eth.db.backends.memory import MemoryDB
     from trinity.protocol.eth.peer import ETHPeer
+    from trinity.protocol.eth.requests import HeaderRequest as ETHHeaderRequest
     from trinity.protocol.les.peer import LESPeer
+    from trinity.protocol.les.requests import HeaderRequest as LESHeaderRequest
     from tests.p2p.integration_test_helpers import FakeAsyncHeaderDB, connect_to_peers_loop
     logging.basicConfig(level=TRACE_LEVEL_NUM, format='%(asctime)s %(levelname)s: %(message)s')
@@ -1041,13 +1036,15 @@ async def request_stuff() -> None:
             '0x59af08ab31822c992bb3dad92ddb68d820aa4c69e9560f07081fa53f1009b152')
         if peer_class == ETHPeer:
             peer = cast(ETHPeer, peer)
-            peer.sub_proto.send_get_block_headers(block_hash, 1, 0, False)
+            peer.sub_proto.send_get_block_headers(ETHHeaderRequest(block_hash, 1, 0, False))
             peer.sub_proto.send_get_block_bodies([block_hash])
             peer.sub_proto.send_get_receipts([block_hash])
         else:
             peer = cast(LESPeer, peer)
             request_id = 1
-            peer.sub_proto.send_get_block_headers(block_hash, 1, 0, False, request_id)
+            peer.sub_proto.send_get_block_headers(
+                LESHeaderRequest(block_hash, 1, 0, False, request_id)
+            )
             peer.sub_proto.send_get_block_bodies([block_hash], request_id + 1)
             peer.sub_proto.send_get_receipts(block_hash, request_id + 2)
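
The demo now wraps the raw arguments in per-protocol HeaderRequest objects. As a rough sketch of the state such a request value carries (the real fields live in trinity.protocol.eth.requests and trinity.protocol.les.requests and may differ):

    from typing import NamedTuple, Union
    from eth_typing import BlockNumber, Hash32

    class HeaderRequestSketch(NamedTuple):
        # Hypothetical stand-in: enough state to serialize a GetBlockHeaders
        # message and later validate the headers that come back.
        block_number_or_hash: Union[BlockNumber, Hash32]
        max_headers: int
        skip: int
        reverse: bool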

p2p/service.py

Lines changed: 32 additions & 13 deletions
@@ -17,9 +17,19 @@
 from p2p.cancellable import CancellableMixin
 
 
+class ServiceEvents:
+    def __init__(self) -> None:
+        self.started = asyncio.Event()
+        self.stopped = asyncio.Event()
+        self.cleaned_up = asyncio.Event()
+        self.cancelled = asyncio.Event()
+        self.finished = asyncio.Event()
+
+
 class BaseService(ABC, CancellableMixin):
     logger: TraceLogger = None
     _child_services: List['BaseService']
+    _finished_callbacks: List[Callable[['BaseService'], None]]
     # Number of seconds cancel() will wait for run() to finish.
     _wait_until_finished_timeout = 5

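Grouping the lifecycle flags into a ServiceEvents namespace lets callers await specific transitions instead of polling; `PeerPool.start_peer` above does exactly this with `events.started`. A small sketch, assuming `service` is a BaseService being run elsewhere:

    import asyncio

    async def wait_until_started(service, timeout: float = 1.0) -> None:
        # Returns once the service's run() loop has actually begun, or
        # raises asyncio.TimeoutError if it never does.
        await asyncio.wait_for(service.events.started.wait(), timeout=timeout)
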
@@ -32,8 +42,9 @@ def __init__(self, token: CancelToken=None, loop: asyncio.AbstractEventLoop = No
             TraceLogger, logging.getLogger(self.__module__ + '.' + self.__class__.__name__))
 
         self._run_lock = asyncio.Lock()
-        self.cleaned_up = asyncio.Event()
+        self.events = ServiceEvents()
         self._child_services = []
+        self._finished_callbacks = []
 
         self.loop = loop
         base_token = CancelToken(type(self).__name__, loop=loop)
@@ -56,8 +67,12 @@ async def run(
         elif self.cancel_token.triggered:
             raise RuntimeError("Cannot restart a service that has already been cancelled")
 
+        if finished_callback:
+            self._finished_callbacks.append(finished_callback)
+
         try:
             async with self._run_lock:
+                self.events.started.set()
                 await self._run()
         except OperationCancelled as e:
             self.logger.info("%s finished: %s", self, e)
@@ -66,19 +81,20 @@ async def run(
         finally:
             # Trigger our cancel token to ensure all pending asyncio tasks and background
             # coroutines started by this service exit cleanly.
+            self.events.cancelled.set()
             self.cancel_token.trigger()
 
             await self.cleanup()
 
-            from p2p.peer import BasePeer  # type: ignore
-            if finished_callback is not None:
-                finished_callback(self)
-            elif isinstance(self, BasePeer):
-                # XXX: Only added to help debug https://github.com/ethereum/py-evm/issues/1023;
-                # should be removed eventually.
-                self.logger.warn("%s finished but had no finished_callback", self)
+            for callback in self._finished_callbacks:
+                callback(self)
+
+            self.events.finished.set()
             self.logger.debug("%s halted cleanly", self)
 
+    def add_finished_callback(self, finished_callback: Callable[['BaseService'], None]) -> None:
+        self._finished_callbacks.append(finished_callback)
+
     def run_child_service(self, child_service: 'BaseService') -> 'asyncio.Future[Any]':
         """
         Run a child service and keep a reference to it to be considered during the cleanup.
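
With run() no longer owning a single callback, any number of observers can register after a service is constructed; `_add_peer` uses this to hook `_peer_finished`. A usage sketch with stand-in objects:

    def watch(peer_pool, peer) -> None:
        # Both callbacks fire, in registration order, once the peer's run()
        # has finished and its cleanup completed.
        peer.add_finished_callback(peer_pool._peer_finished)
        peer.add_finished_callback(lambda svc: print(svc, 'finished'))
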
@@ -94,13 +110,12 @@ async def cleanup(self) -> None:
         The ``_cleanup()`` coroutine is invoked before the child services may have finished
         their cleanup.
         """
-
         await asyncio.gather(*[
-            child_service.cleaned_up.wait()
+            child_service.events.cleaned_up.wait()
             for child_service in self._child_services],
             self._cleanup()
         )
-        self.cleaned_up.set()
+        self.events.cleaned_up.set()
 
     async def cancel(self) -> None:
         """Trigger the CancelToken and wait for the cleaned_up event to be set."""
@@ -111,10 +126,11 @@ async def cancel(self) -> None:
             raise RuntimeError("Cannot cancel a service that has not been started")
 
         self.logger.debug("Cancelling %s", self)
+        self.events.cancelled.set()
         self.cancel_token.trigger()
         try:
             await asyncio.wait_for(
-                self.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
+                self.events.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
         except asyncio.futures.TimeoutError:
             self.logger.info("Timed out waiting for %s to finish its cleanup, exiting anyway", self)
         else:
@@ -131,7 +147,10 @@ async def threadsafe_cancel(self) -> None:
         :param poll_period: how many seconds to wait in between each check for service cleanup
         """
         asyncio.run_coroutine_threadsafe(self.cancel(), loop=self.loop)
-        await asyncio.wait_for(self.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
+        await asyncio.wait_for(
+            self.events.cleaned_up.wait(),
+            timeout=self._wait_until_finished_timeout,
+        )
 
     async def sleep(self, delay: float) -> None:
         """Coroutine that completes after a given time (in seconds)."""

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def vm_logger(namespaces=LOGGING_NAMESPACES):
 
     handler = logging.StreamHandler(sys.stdout)
 
-    # level = TRACE_LEVEL_NUM
+    # level = 5 # TRACE
     # level = logging.DEBUG
     # level = logging.INFO
     level = logging.ERROR
