Skip to content

Commit b1e93bb

Browse files
committed
add Client.wait_for_engines(n)
handy method to wait for a number of engines to be ready
1 parent 854ff2e commit b1e93bb

File tree

3 files changed

+88
-5
lines changed

3 files changed

+88
-5
lines changed

ipyparallel/client/client.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# Distributed under the terms of the Modified BSD License.
44
from __future__ import print_function
55

6-
import threading
7-
86
try:
97
from collections.abc import Iterable
108
except ImportError: # py2
@@ -367,6 +365,7 @@ def _profile_default(self):
367365
_mux_socket = Instance('zmq.Socket', allow_none=True)
368366
_task_socket = Instance('zmq.Socket', allow_none=True)
369367
_broadcast_socket = Instance('zmq.Socket', allow_none=True)
368+
_registration_callbacks = List()
370369

371370
_task_scheme = Unicode()
372371
_closed = False
@@ -778,6 +777,8 @@ def _register_engine(self, msg):
778777
eid = content['id']
779778
d = {eid: content['uuid']}
780779
self._update_engines(d)
780+
for callback in self._registration_callbacks:
781+
callback(content)
781782

782783
def _unregister_engine(self, msg):
783784
"""Unregister an engine that has died."""
@@ -1290,6 +1291,80 @@ def _futures_for_msgs(self, msg_ids):
12901291
futures.append(f)
12911292
return futures
12921293

1294+
def wait_for_engines(self, n, *, timeout=-1, block=True):
1295+
"""Wait for `n` engines to become available.
1296+
1297+
Returns when `n` engines are available,
1298+
or raises TimeoutError if `timeout` is reached
1299+
before `n` engines are ready.
1300+
1301+
Parameters
1302+
----------
1303+
n : int
1304+
Number of engines to wait for.
1305+
timeout : float
1306+
Time (in seconds) to wait before raising a TimeoutError
1307+
block : bool
1308+
if False, return Future instead of waiting
1309+
1310+
Returns
1311+
------
1312+
f : concurrent.futures.Future or None
1313+
Future object to wait on if block is False,
1314+
None if block is True.
1315+
1316+
Raises
1317+
------
1318+
TimeoutError : if timeout is reached.
1319+
"""
1320+
if len(self.ids) >= n:
1321+
return
1322+
tic = now = time.perf_counter()
1323+
if timeout >= 0:
1324+
deadline = tic + timeout
1325+
else:
1326+
deadline = None
1327+
seconds_remaining = 1000
1328+
1329+
future = Future()
1330+
1331+
def notify(_):
1332+
if future.done():
1333+
return
1334+
if len(self.ids) >= n:
1335+
future.set_result(None)
1336+
1337+
future.add_done_callback(lambda f: self._registration_callbacks.remove(notify))
1338+
self._registration_callbacks.append(notify)
1339+
1340+
def on_timeout():
1341+
"""Called when timeout is reached"""
1342+
if future.done():
1343+
return
1344+
1345+
if len(self.ids) >= n:
1346+
future.set_result(None)
1347+
else:
1348+
future.set_exception(
1349+
TimeoutError(
1350+
"{n} engines not ready in {timeout} seconds. Currently ready: {len(self.ids)}"
1351+
)
1352+
)
1353+
1354+
def schedule_timeout():
1355+
handle = self._io_loop.add_timeout(
1356+
self._io_loop.time() + timeout, on_timeout
1357+
)
1358+
future.add_done_callback(lambda f: self._io_loop.remove_timeout(handle))
1359+
1360+
if timeout >= 0:
1361+
self._io_loop.add_callback(schedule_timeout)
1362+
1363+
if block:
1364+
return future.result()
1365+
else:
1366+
return future
1367+
12931368
def wait(self, jobs=None, timeout=-1):
12941369
"""waits on one or more `jobs`, for up to `timeout` seconds.
12951370

ipyparallel/tests/clienttest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def minimum_engines(self, n=1, block=True):
108108
def wait_on_engines(self, timeout=5):
109109
"""wait for our engines to connect."""
110110
n = len(self.engines) + self.base_engine_count
111-
tic = time.time()
112-
while time.time() - tic < timeout and len(self.client.ids) < n:
113-
time.sleep(0.1)
111+
self.client.wait_for_engines(n, timeout=timeout)
114112

115113
assert not len(self.client.ids) < n, "waiting for engines timed out"
116114

ipyparallel/tests/test_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,13 @@ def test_local_ip_true_doesnt_trigger_warning(self):
641641
]
642642
assert len(runtime_warnings) == 0, str([str(w) for w in runtime_warnings])
643643
c.close()
644+
645+
def test_wait_for_engines(self):
646+
n = len(self.client)
647+
assert self.client.wait_for_engines(n) is None
648+
with pytest.raises(TimeoutError):
649+
self.client.wait_for_engines(n + 1, timeout=0.1)
650+
651+
f = self.client.wait_for_engines(n + 1, timeout=10, block=False)
652+
self.add_engines(1)
653+
assert f.result() is None

0 commit comments

Comments
 (0)