Skip to content

Commit 73ea0e5

Browse files
committed
PR Feedback
1 parent 69afd31 commit 73ea0e5

File tree

13 files changed

+125
-64
lines changed

13 files changed

+125
-64
lines changed

p2p/peer.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from eth_utils import (
4040
decode_hex,
41+
to_tuple,
4142
)
4243

4344
from eth_typing import BlockNumber, Hash32
@@ -199,6 +200,18 @@ async def process_sub_proto_handshake(
199200
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
200201
raise NotImplementedError("Must be implemented by subclasses")
201202

203+
@contextlib.contextmanager
204+
def collect_sub_proto_messages(self) -> Iterator['MsgCollector']:
205+
"""
206+
Can be used to gather up all messages that are sent to the peer.
207+
"""
208+
if not self.is_running:
209+
raise RuntimeError("Cannot collect messages if peer is not running")
210+
msg_collector = MsgCollector()
211+
212+
with msg_collector.subscribe_peer(self):
213+
yield msg_collector
214+
202215
@property
203216
def received_msgs_count(self) -> int:
204217
return sum(self.received_msgs.values())
@@ -622,6 +635,17 @@ def subscribe_peer(self, peer: BasePeer) -> Iterator[None]:
622635
peer.remove_subscriber(self)
623636

624637

638+
class MsgCollector(PeerSubscriber):
639+
logger = logging.getLogger('p2p.peer.MsgCollector')
640+
msg_queue_maxsize = 200
641+
subscription_msg_types = {protocol.Command}
642+
643+
@to_tuple
644+
def get_messages(self) -> Iterator['PEER_MSG_TYPE']:
645+
while not self.msg_queue.empty():
646+
yield self.msg_queue.get_nowait()
647+
648+
625649
class PeerPool(BaseService, AsyncIterable[BasePeer]):
626650
"""
627651
PeerPool maintains connections to up-to max_peers on a given network.
@@ -676,21 +700,25 @@ def unsubscribe(self, subscriber: PeerSubscriber) -> None:
676700
peer.remove_subscriber(subscriber)
677701

678702
async def start_peer(self, peer: BasePeer) -> None:
703+
asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
704+
await self.wait(peer.events.started.wait(), timeout=1)
679705
try:
680706
# Although connect() may seem like a more appropriate place to perform the DAO fork
681707
# check, we do it here because we want to perform it for incoming peer connections as
682708
# well.
683-
msgs = await self.ensure_same_side_on_dao_fork(peer)
709+
with peer.collect_sub_proto_messages() as collector:
710+
await self.ensure_same_side_on_dao_fork(peer)
684711
except DAOForkCheckFailure as err:
685712
self.logger.debug("DAO fork check with %s failed: %s", peer, err)
686713
await peer.disconnect(DisconnectReason.useless_peer)
687714
return
688-
asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
689-
self._add_peer(peer, msgs)
715+
else:
716+
msgs = tuple((cmd, msg) for _, cmd, msg in collector.get_messages())
717+
self._add_peer(peer, msgs)
690718

691719
def _add_peer(self,
692720
peer: BasePeer,
693-
msgs: List[Tuple[protocol.Command, protocol._DecodedMsgType]]) -> None:
721+
msgs: Tuple[Tuple[protocol.Command, protocol._DecodedMsgType], ...]) -> None:
694722
"""Add the given peer to the pool.
695723
696724
Appart from adding it to our list of connected nodes and adding each of our subscriber's
@@ -768,7 +796,7 @@ async def connect_to_nodes(self, nodes: Iterator[Node]) -> None:
768796
await self.start_peer(peer)
769797

770798
async def ensure_same_side_on_dao_fork(
771-
self, peer: BasePeer) -> List[Tuple[protocol.Command, protocol._DecodedMsgType]]:
799+
self, peer: BasePeer) -> None:
772800
"""Ensure we're on the same side of the DAO fork as the given peer.
773801
774802
In order to do that we have to request the DAO fork block and its parent, but while we
@@ -785,25 +813,14 @@ async def ensure_same_side_on_dao_fork(
785813
break
786814

787815
start_block = vm_class.dao_fork_block_number - 1
788-
# TODO: This can be either an `ETHPeer` or an `LESPeer`. Will be
789-
# fixed once full awaitable request API is completed.
790-
try:
791-
class MsgBuffer(PeerSubscriber):
792-
logger = logging.getLogger('p2p.peer.MsgBuffer')
793-
msg_queue_maxsize = 200
794-
subscription_msg_types = {protocol.Command}
795-
796-
msg_buffer = MsgBuffer()
797-
798-
with msg_buffer.subscribe_peer(peer):
799-
headers = await peer.handler.get_block_headers( # type: ignore
800-
start_block,
801-
max_headers=2,
802-
reverse=False,
803-
timeout=CHAIN_SPLIT_CHECK_TIMEOUT,
804-
)
805816

806-
msgs = [msg_buffer.msg_queue.get_nowait()[1:] for _ in range(msg_buffer.queue_size)]
817+
try:
818+
headers = await peer.requests.get_block_headers( # type: ignore
819+
start_block,
820+
max_headers=2,
821+
reverse=False,
822+
timeout=CHAIN_SPLIT_CHECK_TIMEOUT,
823+
)
807824

808825
except (TimeoutError, PeerConnectionLost) as err:
809826
raise DAOForkCheckFailure(
@@ -820,15 +837,18 @@ class MsgBuffer(PeerSubscriber):
820837
"Invalid header response during DAO fork check: {}".format(err)
821838
) from err
822839

823-
parent, header = headers
840+
if len(headers) != 2:
841+
raise DAOForkCheckFailure(
842+
"Peer %s failed to return DAO fork check headers".format(peer)
843+
)
844+
else:
845+
parent, header = headers
824846

825847
try:
826848
vm_class.validate_header(header, parent, check_seal=True)
827849
except EthValidationError as err:
828850
raise DAOForkCheckFailure("Peer failed DAO fork check validation: {}".format(err))
829851

830-
return msgs
831-
832852
def _peer_finished(self, peer: BaseService) -> None:
833853
"""Remove the given peer from our list of connected nodes.
834854
This is passed as a callback to be called when a peer finishes.

p2p/service.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
from p2p.cancellable import CancellableMixin
1818

1919

20+
class ServiceEvents:
21+
def __init__(self) -> None:
22+
self.started = asyncio.Event()
23+
self.stopped = asyncio.Event()
24+
self.cleaned_up = asyncio.Event()
25+
self.cancelled = asyncio.Event()
26+
self.finished = asyncio.Event()
27+
28+
2029
class BaseService(ABC, CancellableMixin):
2130
logger: TraceLogger = None
2231
_child_services: List['BaseService']
@@ -32,7 +41,7 @@ def __init__(self, token: CancelToken=None, loop: asyncio.AbstractEventLoop = No
3241
TraceLogger, logging.getLogger(self.__module__ + '.' + self.__class__.__name__))
3342

3443
self._run_lock = asyncio.Lock()
35-
self.cleaned_up = asyncio.Event()
44+
self.events = ServiceEvents()
3645
self._child_services = []
3746

3847
self.loop = loop
@@ -58,6 +67,7 @@ async def run(
5867

5968
try:
6069
async with self._run_lock:
70+
self.events.started.set()
6171
await self._run()
6272
except OperationCancelled as e:
6373
self.logger.info("%s finished: %s", self, e)
@@ -66,6 +76,7 @@ async def run(
6676
finally:
6777
# Trigger our cancel token to ensure all pending asyncio tasks and background
6878
# coroutines started by this service exit cleanly.
79+
self.events.cancelled.set()
6980
self.cancel_token.trigger()
7081

7182
await self.cleanup()
@@ -77,6 +88,7 @@ async def run(
7788
# XXX: Only added to help debug https://github.com/ethereum/py-evm/issues/1023;
7889
# should be removed eventually.
7990
self.logger.warn("%s finished but had no finished_callback", self)
91+
self.events.finished.set()
8092
self.logger.debug("%s halted cleanly", self)
8193

8294
def run_child_service(self, child_service: 'BaseService') -> 'asyncio.Future[Any]':
@@ -95,11 +107,11 @@ async def cleanup(self) -> None:
95107
their cleanup.
96108
"""
97109
await asyncio.gather(*[
98-
child_service.cleaned_up.wait()
110+
child_service.events.cleaned_up.wait()
99111
for child_service in self._child_services],
100112
self._cleanup()
101113
)
102-
self.cleaned_up.set()
114+
self.events.cleaned_up.set()
103115

104116
async def cancel(self) -> None:
105117
"""Trigger the CancelToken and wait for the cleaned_up event to be set."""
@@ -110,10 +122,11 @@ async def cancel(self) -> None:
110122
raise RuntimeError("Cannot cancel a service that has not been started")
111123

112124
self.logger.debug("Cancelling %s", self)
125+
self.events.cancelled.set()
113126
self.cancel_token.trigger()
114127
try:
115128
await asyncio.wait_for(
116-
self.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
129+
self.events.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
117130
except asyncio.futures.TimeoutError:
118131
self.logger.info("Timed out waiting for %s to finish its cleanup, exiting anyway", self)
119132
else:
@@ -130,7 +143,10 @@ async def threadsafe_cancel(self) -> None:
130143
:param poll_period: how many seconds to wait in between each check for service cleanup
131144
"""
132145
asyncio.run_coroutine_threadsafe(self.cancel(), loop=self.loop)
133-
await asyncio.wait_for(self.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
146+
await asyncio.wait_for(
147+
self.events.cleaned_up.wait(),
148+
timeout=self._wait_until_finished_timeout,
149+
)
134150

135151
async def sleep(self, delay: float) -> None:
136152
"""Coroutine that completes after a given time (in seconds)."""

tests/trinity/core/p2p-proto/test_peer_block_header_request_and_response_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def send_headers():
7474
remote.sub_proto.send_block_headers(headers)
7575

7676
asyncio.ensure_future(send_headers())
77-
response = await peer.handler.get_block_headers(*params)
77+
response = await peer.requests.get_block_headers(*params)
7878

7979
assert len(response) == len(headers)
8080
for expected, actual in zip(headers, response):
@@ -96,12 +96,12 @@ async def test_les_peer_get_headers_round_trip(les_peer_and_remote,
9696
peer, remote = les_peer_and_remote
9797

9898
async def send_headers():
99-
request_id = peer.handler.get_block_headers.pending_request[0].request_id
99+
request_id = peer.requests.get_block_headers.pending_request[0].request_id
100100
remote.sub_proto.send_block_headers(headers, 0, request_id)
101101
await asyncio.sleep(0)
102102

103103
asyncio.ensure_future(send_headers())
104-
response = await peer.handler.get_block_headers(*params)
104+
response = await peer.requests.get_block_headers(*params)
105105

106106
assert len(response) == len(headers)
107107
for expected, actual in zip(headers, response):
@@ -121,7 +121,7 @@ async def send_responses():
121121
await asyncio.sleep(0)
122122

123123
asyncio.ensure_future(send_responses())
124-
response = await peer.handler.get_block_headers(0, 10, 0, False)
124+
response = await peer.requests.get_block_headers(0, 10, 0, False)
125125

126126
assert len(response) == len(headers)
127127
for expected, actual in zip(headers, response):
@@ -145,7 +145,7 @@ async def send_responses():
145145
await asyncio.sleep(0)
146146

147147
asyncio.ensure_future(send_responses())
148-
response = await peer.handler.get_block_headers(0, 5, 0, False)
148+
response = await peer.requests.get_block_headers(0, 5, 0, False)
149149

150150
assert len(response) == len(headers)
151151
for expected, actual in zip(headers, response):

trinity/nodes/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,11 @@ async def _cleanup(self) -> None:
151151
# Stop the this IPCServer-specific event loop, so that the IPCServer thread will exit
152152
self._ipc_loop.stop()
153153

154-
await asyncio.gather(*[service.cleaned_up.wait() for service in self._auxiliary_services])
154+
await asyncio.gather(*[
155+
service.events.cleaned_up.wait()
156+
for service
157+
in self._auxiliary_services
158+
])
155159

156160
def _make_new_loop_thread(self) -> asyncio.AbstractEventLoop:
157161
new_loop = asyncio.new_event_loop()

trinity/protocol/common/handlers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ async def _run(self) -> None:
3535
manager = getattr(self, attr)
3636
self.run_child_service(manager)
3737

38-
while not self.cancel_token.triggered:
39-
await self.sleep(0.01)
38+
await self.cancel_token.wait()
4039

4140
async def _cleanup(self) -> None:
4241
pass

trinity/protocol/common/managers.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
ReturnType = TypeVar('ReturnType')
3232

3333

34-
class BaseRequestManager(PeerSubscriber, BaseService, Generic[PeerClass, RequestType, ResponseType, ReturnType]):
34+
class BaseRequestManager(PeerSubscriber, BaseService, Generic[PeerClass, RequestType, ResponseType, ReturnType]): # noqa: E501
3535
#
3636
# PeerSubscriber
3737
#
@@ -55,10 +55,10 @@ def __init__(self, peer: PeerClass, token: CancelToken) -> None:
5555
# Service API
5656
#
5757
async def _run(self) -> None:
58-
self.logger.debug("Running %s for peer %s", self.__class__.__name__, self._peer)
58+
self.logger.debug("Launching %s for peer %s", self.__class__.__name__, self._peer)
5959

6060
with self.subscribe_peer(self._peer):
61-
while True:
61+
while self.is_running:
6262
peer, cmd, msg = await self.wait(
6363
self.msg_queue.get(), token=self.cancel_token)
6464
if peer != self._peer:
@@ -74,6 +74,9 @@ async def _cleanup(self) -> None:
7474

7575
def _handle_msg(self, msg: ResponseType) -> None:
7676
if self.pending_request is None:
77+
self.logger.debug(
78+
"Got unexpected %s message from %", self.response_msg_name, self._peer
79+
)
7780
return
7881

7982
request, future = self.pending_request
@@ -83,7 +86,7 @@ def _handle_msg(self, msg: ResponseType) -> None:
8386
except ValidationError as err:
8487
self.logger.debug(
8588
"Response validation failure for pending %s request from peer %s: %s",
86-
self._response_msg_type.__name__,
89+
self.response_msg_name,
8790
self._peer,
8891
err,
8992
)
@@ -109,6 +112,10 @@ def __call__(self) -> ReturnType: # type: ignore
109112
def _response_msg_type(self) -> Type[Command]:
110113
pass
111114

115+
@property
116+
def response_msg_name(self) -> str:
117+
return self._response_msg_type.__name__
118+
112119
@abstractmethod
113120
def _send_sub_proto_request(self, request: RequestType) -> None:
114121
pass
@@ -119,12 +126,12 @@ async def _wait_for_response(self,
119126
if self.pending_request is not None:
120127
self.logger.error(
121128
"Already waiting for response to %s for peer: %s",
122-
self._response_msg_type.__name__,
129+
self.response_msg_name,
123130
self._peer,
124131
)
125132
raise AlreadyWaiting(
126133
"Already waiting for response to {0} for peer: {1}".format(
127-
self._response_msg_type.__name__,
134+
self.response_msg_name,
128135
self._peer
129136
)
130137
)

trinity/protocol/eth/managers.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
from trinity.protocol.common.managers import (
16-
BaseRequestManager,
16+
BaseRequestManager as _BaseRequestManager,
1717
)
1818

1919
from .commands import BlockHeaders
@@ -23,7 +23,15 @@
2323
from .peer import ETHPeer # noqa: F401
2424

2525

26-
class GetBlockHeadersRequestManager(BaseRequestManager['ETHPeer', HeaderRequest, Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]):
26+
BaseRequestManager = _BaseRequestManager[
27+
'ETHPeer',
28+
HeaderRequest,
29+
Tuple[BlockHeader, ...],
30+
Tuple[BlockHeader, ...],
31+
]
32+
33+
34+
class GetBlockHeadersRequestManager(BaseRequestManager):
2735
msg_queue_maxsize = 100
2836

2937
_response_msg_type: Type[Command] = BlockHeaders

0 commit comments

Comments
 (0)