|
3 | 3 | # Distributed under the terms of the Modified BSD License.
|
4 | 4 | from __future__ import print_function
|
5 | 5 |
|
6 |
| -import threading |
7 |
| - |
8 | 6 | try:
|
9 | 7 | from collections.abc import Iterable
|
10 | 8 | except ImportError: # py2
|
@@ -367,6 +365,7 @@ def _profile_default(self):
|
367 | 365 | _mux_socket = Instance('zmq.Socket', allow_none=True)
|
368 | 366 | _task_socket = Instance('zmq.Socket', allow_none=True)
|
369 | 367 | _broadcast_socket = Instance('zmq.Socket', allow_none=True)
|
| 368 | + _registration_callbacks = List() |
370 | 369 |
|
371 | 370 | _task_scheme = Unicode()
|
372 | 371 | _closed = False
|
@@ -778,6 +777,8 @@ def _register_engine(self, msg):
|
778 | 777 | eid = content['id']
|
779 | 778 | d = {eid: content['uuid']}
|
780 | 779 | self._update_engines(d)
|
| 780 | + for callback in self._registration_callbacks: |
| 781 | + callback(content) |
781 | 782 |
|
782 | 783 | def _unregister_engine(self, msg):
|
783 | 784 | """Unregister an engine that has died."""
|
@@ -1290,6 +1291,80 @@ def _futures_for_msgs(self, msg_ids):
|
1290 | 1291 | futures.append(f)
|
1291 | 1292 | return futures
|
1292 | 1293 |
|
| 1294 | + def wait_for_engines(self, n, *, timeout=-1, block=True): |
| 1295 | + """Wait for `n` engines to become available. |
| 1296 | +
|
| 1297 | + Returns when `n` engines are available, |
| 1298 | + or raises TimeoutError if `timeout` is reached |
| 1299 | + before `n` engines are ready. |
| 1300 | +
|
| 1301 | + Parameters |
| 1302 | + ---------- |
| 1303 | + n : int |
| 1304 | + Number of engines to wait for. |
| 1305 | + timeout : float |
| 1306 | + Time (in seconds) to wait before raising a TimeoutError |
| 1307 | + block : bool |
| 1308 | + if False, return Future instead of waiting |
| 1309 | +
|
| 1310 | + Returns |
| 1311 | + ------ |
| 1312 | + f : concurrent.futures.Future or None |
| 1313 | + Future object to wait on if block is False, |
| 1314 | + None if block is True. |
| 1315 | +
|
| 1316 | + Raises |
| 1317 | + ------ |
| 1318 | + TimeoutError : if timeout is reached. |
| 1319 | + """ |
| 1320 | + if len(self.ids) >= n: |
| 1321 | + return |
| 1322 | + tic = now = time.perf_counter() |
| 1323 | + if timeout >= 0: |
| 1324 | + deadline = tic + timeout |
| 1325 | + else: |
| 1326 | + deadline = None |
| 1327 | + seconds_remaining = 1000 |
| 1328 | + |
| 1329 | + future = Future() |
| 1330 | + |
| 1331 | + def notify(_): |
| 1332 | + if future.done(): |
| 1333 | + return |
| 1334 | + if len(self.ids) >= n: |
| 1335 | + future.set_result(None) |
| 1336 | + |
| 1337 | + future.add_done_callback(lambda f: self._registration_callbacks.remove(notify)) |
| 1338 | + self._registration_callbacks.append(notify) |
| 1339 | + |
| 1340 | + def on_timeout(): |
| 1341 | + """Called when timeout is reached""" |
| 1342 | + if future.done(): |
| 1343 | + return |
| 1344 | + |
| 1345 | + if len(self.ids) >= n: |
| 1346 | + future.set_result(None) |
| 1347 | + else: |
| 1348 | + future.set_exception( |
| 1349 | + TimeoutError( |
| 1350 | + "{n} engines not ready in {timeout} seconds. Currently ready: {len(self.ids)}" |
| 1351 | + ) |
| 1352 | + ) |
| 1353 | + |
| 1354 | + def schedule_timeout(): |
| 1355 | + handle = self._io_loop.add_timeout( |
| 1356 | + self._io_loop.time() + timeout, on_timeout |
| 1357 | + ) |
| 1358 | + future.add_done_callback(lambda f: self._io_loop.remove_timeout(handle)) |
| 1359 | + |
| 1360 | + if timeout >= 0: |
| 1361 | + self._io_loop.add_callback(schedule_timeout) |
| 1362 | + |
| 1363 | + if block: |
| 1364 | + return future.result() |
| 1365 | + else: |
| 1366 | + return future |
| 1367 | + |
1293 | 1368 | def wait(self, jobs=None, timeout=-1):
|
1294 | 1369 | """waits on one or more `jobs`, for up to `timeout` seconds.
|
1295 | 1370 |
|
|
0 commit comments