Skip to content

Commit 2558d6c

Browse files
committed
Implement BaseService.run_task(awaitable)
It runs the given awaitable in the background, catching and ignoring OperationCancelled, and ensures the task has finished before the service itself finishes. Also removes trinity.Node.add_service(), as the same functionality is now provided by BaseService.run_child_service(). Closes: #1118
1 parent 5534db7 commit 2558d6c

File tree

13 files changed

+73
-106
lines changed

13 files changed

+73
-106
lines changed

p2p/DEVELOPMENT.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,17 @@ library.
1717

1818
## BaseService
1919

20-
- If your service runs coroutines in the background (e.g. via `asyncio.ensure_future`), you must
21-
ensure they exit when `is_running` is False or when the cancel token is triggered
20+
- If your service needs to run coroutines in the background, you should use the `BaseService.run_task()` method and
21+
ensure they exit when `is_running` is False or when the cancel token is triggered.
2222
- If your service runs other services in the background, you should pass your CancelToken down to
23-
those services and ensure your `_cleanup()` waits for them to cleanup as well
23+
those services and run those using `BaseService.run_child_service()`.
2424

2525
```Python
2626
class Node(BaseService):
2727
async def _run(self):
2828
self.discovery = DiscoveryService(token=self.cancel_token)
29-
asyncio.ensure_future(self.discovery.run())
29+
self.run_child_service(self.discovery)
30+
self.run_task(self.discovery.bootstrap())
3031
# Node's run logic goes here...
3132

32-
async def _cleanup(self):
33-
await self.discovery.cleaned_up.wait()
3433
```

p2p/discovery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,18 +335,18 @@ def __init__(
335335

336336
async def _run(self) -> None:
337337
connect_loop_sleep = 2
338-
asyncio.ensure_future(self.proto.bootstrap())
338+
self.run_task(self.proto.bootstrap())
339339
while True:
340340
await self.maybe_connect_to_more_peers()
341-
await self.wait(asyncio.sleep(connect_loop_sleep))
341+
await self.sleep(connect_loop_sleep)
342342

343343
async def maybe_connect_to_more_peers(self) -> None:
344344
"""Connect to more peers if we're not yet maxed out to max_peers"""
345345
if self.peer_pool.is_full:
346346
self.logger.debug("Already connected to %s peers; sleeping", len(self.peer_pool))
347347
return
348348

349-
asyncio.ensure_future(self.maybe_lookup_random_node())
349+
self.run_task(self.maybe_lookup_random_node())
350350

351351
await self.peer_pool.connect_to_nodes(
352352
self.proto.get_nodes_to_connect(self.peer_pool.max_peers))

p2p/peer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def unsubscribe(self, subscriber: PeerSubscriber) -> None:
720720
peer.remove_subscriber(subscriber)
721721

722722
async def start_peer(self, peer: BasePeer) -> None:
723-
asyncio.ensure_future(peer.run())
723+
self.run_task(peer.run())
724724
await self.wait(peer.events.started.wait(), timeout=1)
725725
try:
726726
# Although connect() may seem like a more appropriate place to perform the DAO fork
@@ -756,7 +756,7 @@ def _add_peer(self,
756756
async def _run(self) -> None:
757757
# FIXME: PeerPool should probably no longer be a BaseService, but for now we're keeping it
758758
# so in order to ensure we cancel all peers when we terminate.
759-
asyncio.ensure_future(self._periodically_report_stats())
759+
self.run_task(self._periodically_report_stats())
760760
await self.cancel_token.wait()
761761

762762
async def stop_all_peers(self) -> None:
@@ -933,7 +933,7 @@ async def _periodically_report_stats(self) -> None:
933933
self.logger.debug(" %s", line)
934934
self.logger.debug("== End peer details == ")
935935
try:
936-
await self.wait(asyncio.sleep(self._report_interval))
936+
await self.sleep(self._report_interval)
937937
except OperationCancelled:
938938
break
939939

@@ -1053,7 +1053,8 @@ def _test() -> None:
10531053
ROPSTEN_VM_CONFIGURATION,
10541054
)
10551055

1056-
asyncio.ensure_future(connect_to_peers_loop(peer_pool, nodes))
1056+
asyncio.ensure_future(peer_pool.run())
1057+
peer_pool.run_task(connect_to_peers_loop(peer_pool, nodes))
10571058

10581059
async def request_stuff() -> None:
10591060
# Request some stuff from ropsten's block 2440319
@@ -1086,7 +1087,6 @@ async def exit_on_sigint() -> None:
10861087

10871088
asyncio.ensure_future(exit_on_sigint())
10881089
asyncio.ensure_future(request_stuff())
1089-
asyncio.ensure_future(peer_pool.run())
10901090
loop.set_debug(True)
10911091
loop.run_forever()
10921092
loop.close()

p2p/service.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import logging
55
from typing import (
66
Any,
7+
Awaitable,
78
Callable,
89
List,
910
Optional,
1011
cast,
1112
)
13+
from weakref import WeakSet
1214

1315
from eth.tools.logging import TraceLogger
1416

@@ -30,6 +32,8 @@ def __init__(self) -> None:
3032
class BaseService(ABC, CancellableMixin):
3133
logger: TraceLogger = None
3234
_child_services: List['BaseService']
35+
# Use a WeakSet so that we don't have to bother updating it when tasks finish.
36+
_tasks: 'WeakSet[asyncio.Future[Any]]'
3337
_finished_callbacks: List[Callable[['BaseService'], None]]
3438
# Number of seconds cancel() will wait for run() to finish.
3539
_wait_until_finished_timeout = 5
@@ -45,6 +49,7 @@ def __init__(self,
4549
self.events = ServiceEvents()
4650
self._run_lock = asyncio.Lock()
4751
self._child_services = []
52+
self._tasks = WeakSet()
4853
self._finished_callbacks = []
4954

5055
self._loop = loop
@@ -114,12 +119,26 @@ async def run(
114119
def add_finished_callback(self, finished_callback: Callable[['BaseService'], None]) -> None:
115120
self._finished_callbacks.append(finished_callback)
116121

117-
def run_child_service(self, child_service: 'BaseService') -> 'asyncio.Future[Any]':
122+
def run_task(self, awaitable: Awaitable[Any]) -> None:
123+
"""Run the given awaitable in the background.
124+
125+
The awaitable should return whenever this service's cancel token is triggered.
126+
127+
If it raises OperationCancelled, that is caught and ignored.
128+
"""
129+
async def f() -> None:
130+
try:
131+
await awaitable
132+
except OperationCancelled:
133+
pass
134+
self._tasks.add(asyncio.ensure_future(f()))
135+
136+
def run_child_service(self, child_service: 'BaseService') -> None:
118137
"""
119138
Run a child service and keep a reference to it to be considered during the cleanup.
120139
"""
121140
self._child_services.append(child_service)
122-
return asyncio.ensure_future(child_service.run())
141+
self.run_task(child_service.run())
123142

124143
async def _run_in_executor(self, callback: Callable[..., Any], *args: Any) -> Any:
125144
loop = self.get_event_loop()
@@ -136,6 +155,7 @@ async def cleanup(self) -> None:
136155
await asyncio.gather(*[
137156
child_service.events.cleaned_up.wait()
138157
for child_service in self._child_services],
158+
*[task for task in self._tasks],
139159
self._cleanup()
140160
)
141161
self.events.cleaned_up.set()
@@ -155,10 +175,20 @@ async def cancel(self) -> None:
155175
await asyncio.wait_for(
156176
self.events.cleaned_up.wait(), timeout=self._wait_until_finished_timeout)
157177
except asyncio.futures.TimeoutError:
158-
self.logger.info("Timed out waiting for %s to finish its cleanup, exiting anyway", self)
178+
self.logger.info(
179+
"Timed out waiting for %s to finish its cleanup, forcibly cancelling pending "
180+
"tasks and exiting anyway", self)
181+
self._forcibly_cancel_all_tasks()
182+
# Sleep a bit because the Future.cancel() method just schedules the callbacks, so we
183+
# need to give the event loop a chance to actually call them.
184+
await asyncio.sleep(0.5)
159185
else:
160186
self.logger.debug("%s finished cleanly", self)
161187

188+
def _forcibly_cancel_all_tasks(self) -> None:
189+
for task in self._tasks:
190+
task.cancel()
191+
162192
@property
163193
def is_running(self) -> bool:
164194
return self._run_lock.locked()
@@ -187,13 +217,12 @@ async def _run(self) -> None:
187217
"""
188218
raise NotImplementedError()
189219

190-
@abstractmethod
191220
async def _cleanup(self) -> None:
192221
"""Clean up any resources held by this service.
193222
194223
Called after the service's _run() method returns.
195224
"""
196-
raise NotImplementedError()
225+
pass
197226

198227

199228
def service_timeout(timeout: int) -> Callable[..., Any]:

trinity/nodes/base.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
)
77
from threading import Thread
88
from typing import (
9-
List,
109
Type,
1110
)
1211

@@ -62,7 +61,6 @@ def __init__(self, plugin_manager: PluginManager, chain_config: ChainConfig) ->
6261
self._headerdb = self._db_manager.get_headerdb() # type: ignore
6362

6463
self._jsonrpc_ipc_path: Path = chain_config.jsonrpc_ipc_path
65-
self._auxiliary_services: List[BaseService] = []
6664

6765
@abstractmethod
6866
def get_chain(self) -> BaseChain:
@@ -91,12 +89,6 @@ def db_manager(self) -> BaseManager:
9189
def headerdb(self) -> BaseAsyncHeaderDB:
9290
return self._headerdb
9391

94-
def add_service(self, service: BaseService) -> None:
95-
if self.is_running:
96-
raise RuntimeError("Cannot add an auxiliary service while the node is running")
97-
else:
98-
self._auxiliary_services.append(service)
99-
10092
def notify_resource_available(self) -> None:
10193

10294
# We currently need this to give plugins the chance to start as soon
@@ -139,9 +131,6 @@ async def _run(self) -> None:
139131

140132
asyncio.run_coroutine_threadsafe(self._ipc_server.run(), loop=ipc_loop)
141133

142-
for service in self._auxiliary_services:
143-
asyncio.ensure_future(service.run())
144-
145134
await self.get_p2p_server().run()
146135

147136
async def _cleanup(self) -> None:
@@ -151,12 +140,6 @@ async def _cleanup(self) -> None:
151140
# Stop the this IPCServer-specific event loop, so that the IPCServer thread will exit
152141
self._ipc_loop.stop()
153142

154-
await asyncio.gather(*[
155-
service.events.cleaned_up.wait()
156-
for service
157-
in self._auxiliary_services
158-
])
159-
160143
def _make_new_loop_thread(self) -> asyncio.AbstractEventLoop:
161144
new_loop = asyncio.new_event_loop()
162145

trinity/nodes/light.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ def __init__(self, plugin_manager: PluginManager, chain_config: ChainConfig) ->
5151
self._discovery = DiscoveryService(
5252
self._discovery_proto, self._peer_pool, self.cancel_token)
5353
self._peer_chain = LightPeerChain(self.headerdb, self._peer_pool, self.cancel_token)
54-
self.add_service(self._discovery)
55-
self.add_service(self._peer_pool)
56-
self.add_service(self._peer_chain)
5754
self.notify_resource_available()
5855

5956
async def _run(self) -> None:
60-
# TODO add a datagram endpoint service that can be added with self.add_service
57+
self.run_child_service(self._discovery)
58+
self.run_child_service(self._peer_pool)
59+
self.run_child_service(self._peer_chain)
60+
# TODO add a datagram endpoint service that can be added with self.run_child_service
6161
self.logger.info(
6262
"enode://%s@%s:%s",
6363
self.nodekey.public_key.to_hex()[2:],

trinity/protocol/common/handlers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ async def _run(self) -> None:
3838

3939
await self.cancel_token.wait()
4040

41-
async def _cleanup(self) -> None:
42-
pass
43-
4441
def get_stats(self) -> List[str]:
4542
manager_attrs = self._managers.keys()
4643
return [getattr(self, attr).get_stats() for attr in manager_attrs]

trinity/protocol/common/managers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ async def _run(self) -> None:
102102
else:
103103
self.logger.warning("Unexpected message type: %s", cmd.__class__.__name__)
104104

105-
async def _cleanup(self) -> None:
106-
pass
107-
108105
async def _handle_msg(self, msg: TMsg) -> None:
109106
if self.pending_request is None:
110107
self.logger.debug(

trinity/sync/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _test() -> None:
8282
nodes = DEFAULT_PREFERRED_NODES[network_id]
8383

8484
asyncio.ensure_future(peer_pool.run())
85-
asyncio.ensure_future(connect_to_peers_loop(peer_pool, nodes))
85+
peer_pool.run_task(connect_to_peers_loop(peer_pool, nodes))
8686
chain = chain_class(base_db)
8787
syncer: BaseHeaderChainSyncer = None
8888
if args.fast:

trinity/sync/common/chain.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,30 +76,25 @@ def register_peer(self, peer: BasePeer) -> None:
7676

7777
async def _handle_msg_loop(self) -> None:
7878
while self.is_running:
79-
try:
80-
peer, cmd, msg = await self.wait(self.msg_queue.get())
81-
except OperationCancelled:
82-
break
83-
79+
peer, cmd, msg = await self.wait(self.msg_queue.get())
8480
# Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
85-
# loop can keep processing msgs, and that's why we use ensure_future() instead of
81+
# loop can keep processing msgs, and that's why we use self.run_task() instead of
8682
# awaiting for it to finish here.
87-
asyncio.ensure_future(self.handle_msg(cast(HeaderRequestingPeer, peer), cmd, msg))
83+
self.run_task(self.handle_msg(cast(HeaderRequestingPeer, peer), cmd, msg))
8884

8985
async def handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
9086
msg: protocol._DecodedMsgType) -> None:
9187
try:
9288
await self._handle_msg(peer, cmd, msg)
9389
except OperationCancelled:
94-
# Silently swallow OperationCancelled exceptions because we run unsupervised (i.e.
95-
# with ensure_future()). Our caller will also get an OperationCancelled anyway, and
96-
# there it will be handled.
90+
# Silently swallow OperationCancelled exceptions because otherwise they'll be caught
91+
# by the except below and treated as unexpected.
9792
pass
9893
except Exception:
9994
self.logger.exception("Unexpected error when processing msg from %s", peer)
10095

10196
async def _run(self) -> None:
102-
asyncio.ensure_future(self._handle_msg_loop())
97+
self.run_task(self._handle_msg_loop())
10398
with self.subscribe(self.peer_pool):
10499
while True:
105100
peer_or_finished: Any = await self.wait_first(
@@ -114,12 +109,7 @@ async def _run(self) -> None:
114109

115110
# Since self._sync_complete is not set, peer_or_finished can only be a Peer
116111
# instance.
117-
asyncio.ensure_future(self.sync(peer_or_finished))
118-
119-
async def _cleanup(self) -> None:
120-
# We don't need to cancel() anything, but we yield control just so that the coroutines we
121-
# run in the background notice the cancel token has been triggered and return.
122-
await asyncio.sleep(0)
112+
self.run_task(self.sync(peer_or_finished))
123113

124114
async def sync(self, peer: HeaderRequestingPeer) -> None:
125115
if self._syncing:

0 commit comments

Comments
 (0)