Skip to content

Commit 6b31362

Browse files
authored
Merge pull request #487 from minrk/tqdm-notebook
add some more control to interactive waits
2 parents 508af10 + e5cd966 commit 6b31362

File tree

3 files changed

+81
-21
lines changed

3 files changed

+81
-21
lines changed

ipyparallel/client/asyncresult.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,23 @@
1010
from contextlib import contextmanager
1111
from datetime import datetime
1212
from functools import partial
13+
from queue import Queue
1314
from threading import Event
1415

15-
try:
16-
from queue import Queue
17-
except ImportError: # py2
18-
from Queue import Queue
19-
20-
from decorator import decorator
21-
import tqdm
2216
import zmq
23-
from zmq import MessageTracker
24-
17+
from decorator import decorator
2518
from IPython import get_ipython
26-
from IPython.core.display import display, display_pretty
27-
from ipyparallel import error
28-
from ipyparallel.util import utcnow, compare_datetimes, _parse_date
19+
from IPython.core.display import display
20+
from IPython.core.display import display_pretty
2921
from ipython_genutils.py3compat import string_types
3022

31-
from .futures import MessageFuture, multi_future
23+
from .futures import MessageFuture
24+
from .futures import multi_future
25+
from ipyparallel import error
26+
from ipyparallel.util import _parse_date
27+
from ipyparallel.util import compare_datetimes
28+
from ipyparallel.util import progress
29+
from ipyparallel.util import utcnow
3230

3331

3432
def _raw_text(s):
@@ -38,7 +36,7 @@ def _raw_text(s):
3836
_default = object()
3937

4038
# global empty tracker that's always done:
41-
finished_tracker = MessageTracker()
39+
finished_tracker = zmq.MessageTracker()
4240

4341

4442
@decorator
@@ -75,7 +73,6 @@ def __init__(
7573
fname='unknown',
7674
targets=None,
7775
owner=False,
78-
progress_bar=tqdm.tqdm,
7976
):
8077
super(AsyncResult, self).__init__()
8178
if not isinstance(children, list):
@@ -95,7 +92,6 @@ def __init__(
9592
self._fname = fname
9693
self._targets = targets
9794
self.owner = owner
98-
self.progress_bar = progress_bar
9995

10096
self._ready = False
10197
self._ready_event = Event()
@@ -381,7 +377,7 @@ def _handle_sent(self, f):
381377
"""Resolve sent Future, build MessageTracker"""
382378
trackers = f.result()
383379
trackers = [t for t in trackers if t is not None]
384-
self._tracker = MessageTracker(*trackers)
380+
self._tracker = zmq.MessageTracker(*trackers)
385381
self._sent_event.set()
386382

387383
@property
@@ -569,13 +565,27 @@ def wall_time(self):
569565
"""
570566
return self.timedelta(self.submitted, self.received)
571567

572-
def wait_interactive(self, interval=1.0, timeout=-1):
573-
"""interactive wait, printing progress at regular intervals."""
568+
def wait_interactive(self, interval=0.1, timeout=-1, widget=None):
569+
"""interactive wait, printing progress at regular intervals.
570+
571+
Parameters
572+
----------
573+
interval : float
574+
Interval on which to update progress display.
575+
timeout : float
576+
Time (in seconds) to wait before raising a TimeoutError.
577+
-1 (default) means no timeout.
578+
widget : bool
579+
default: True if in an IPython kernel (notebook), False otherwise.
580+
Override default context-detection behavior for whether a widget-based progress bar
581+
should be used.
582+
"""
574583
if timeout is None:
575584
timeout = -1
576585
N = len(self)
577586
tic = time.perf_counter()
578-
progress_bar = self.progress_bar(total=N, unit='tasks', desc=self._fname)
587+
progress_bar = progress(widget=widget, total=N, unit='tasks', desc=self._fname)
588+
579589
n_prev = 0
580590
while not self.ready() and (
581591
timeout < 0 or time.perf_counter() - tic <= timeout

ipyparallel/client/client.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1307,7 +1307,9 @@ def _futures_for_msgs(self, msg_ids):
13071307
futures.append(f)
13081308
return futures
13091309

1310-
def wait_for_engines(self, n, *, timeout=-1, block=True):
1310+
def wait_for_engines(
1311+
self, n, *, timeout=-1, block=True, interactive=None, widget=None
1312+
):
13111313
"""Wait for `n` engines to become available.
13121314
13131315
Returns when `n` engines are available,
@@ -1322,6 +1324,14 @@ def wait_for_engines(self, n, *, timeout=-1, block=True):
13221324
Time (in seconds) to wait before raising a TimeoutError
13231325
block : bool
13241326
if False, return Future instead of waiting
1327+
interactive : bool
1328+
default: True if in IPython, False otherwise.
1329+
if True, show a progress bar while waiting for engines
1330+
widget : bool
1331+
default: True if in an IPython kernel (notebook), False otherwise.
1332+
Only has an effect if `interactive` is True.
1333+
if True, forces use of widget progress bar.
1334+
If False, forces use of terminal tqdm.
13251335
13261336
Returns
13271337
------
@@ -1347,12 +1357,25 @@ def wait_for_engines(self, n, *, timeout=-1, block=True):
13471357
deadline = None
13481358
seconds_remaining = 1000
13491359

1360+
if interactive is None:
1361+
interactive = get_ipython() is not None
1362+
1363+
if interactive:
1364+
progress_bar = util.progress(
1365+
widget=widget, initial=len(self.ids), total=n, unit='engine'
1366+
)
1367+
13501368
future = Future()
13511369

13521370
def notify(_):
13531371
if future.done():
13541372
return
1373+
if interactive:
1374+
progress_bar.update(len(self.ids) - progress_bar.n)
13551375
if len(self.ids) >= n:
1376+
# ensure we refresh when we finish
1377+
if interactive:
1378+
progress_bar.close()
13561379
future.set_result(None)
13571380

13581381
future.add_done_callback(lambda f: self._registration_callbacks.remove(notify))

ipyparallel/util.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def lru_cache(maxsize=0):
3030
SIGKILL = None
3131
from types import FunctionType
3232

33+
import tqdm
3334
from dateutil.parser import parse as dateutil_parse
3435
from dateutil.tz import tzlocal
3536

@@ -636,3 +637,29 @@ def _patch_jupyter_client_dates():
636637

637638
# FIXME: remove patch when we require jupyter_client 5.0
638639
_patch_jupyter_client_dates()
640+
641+
642+
def progress(*args, widget=None, **kwargs):
643+
"""Create a tqdm progress bar
644+
645+
If `widget` is None, autodetects if IPython widgets should be used,
646+
otherwise use basic tqdm.
647+
"""
648+
if widget is None:
649+
# auto widget if in a kernel
650+
ip = get_ipython()
651+
if ip is not None and getattr(ip, 'kernel', None) is not None:
652+
try:
653+
import ipywidgets # noqa
654+
except ImportError:
655+
widget = False
656+
else:
657+
widget = True
658+
else:
659+
widget = False
660+
if widget:
661+
f = tqdm.tqdm_notebook
662+
else:
663+
kwargs.setdefault("file", sys.stdout)
664+
f = tqdm.tqdm
665+
return f(*args, **kwargs)

0 commit comments

Comments
 (0)