Skip to content

Commit 6667078

Browse files
committed
Capture print statements and show them in notebooks with ipywidgets.
This seems to work well in Jupyteer Notebook, but not Jupyter Lab. Also, unlike before, stderr isn't streamed (how does that work btw?). So, all output is sent back to the client upon the work being completed.
1 parent f11ad7a commit 6667078

File tree

2 files changed

+117
-37
lines changed

2 files changed

+117
-37
lines changed

afar/core.py

Lines changed: 101 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,26 @@
33
import inspect
44
import io
55
import sys
6+
import threading
67
import weakref
8+
from functools import partial
79

810
import innerscope
911
from dask import distributed
1012

1113
from . import reprs
1214

15+
16+
def _supports_async_output():
17+
if reprs.is_kernel() and not reprs.in_terminal():
18+
try:
19+
import ipywidgets # noqa
20+
except ImportError:
21+
return False
22+
return True
23+
return False
24+
25+
1326
_errors_to_locations = {}
1427
try:
1528
remotely
@@ -137,11 +150,13 @@ def __enter__(self):
137150
# Try to fine the source if we are in %%time or %%timeit magic
138151
if (
139152
self._frame.f_code.co_filename in {"<timed exec>", "<magic-timeit>"}
140-
and reprs.in_ipython()
153+
and reprs.is_kernel()
141154
):
142-
import IPython
155+
from IPython import get_ipython
143156

144-
ip = IPython.get_ipython()
157+
ip = get_ipython()
158+
if ip is None:
159+
raise
145160
cell = ip.history_manager._i00 # The current cell!
146161
lines = cell.splitlines(keepends=True)
147162
# strip the magic
@@ -233,6 +248,7 @@ def _exit(self, exc_type, exc_value, exc_traceback):
233248
self.context_body = get_body(self._lines[self._body_start : endline])
234249
self._magic_func, names, futures = abracadabra(self)
235250
display_expr = self._magic_func._display_expr
251+
has_print = "print" in self._magic_func._scoped.builtin_names
236252

237253
if self._where == "remotely":
238254
if client is None:
@@ -268,16 +284,19 @@ def _exit(self, exc_type, exc_value, exc_traceback):
268284
weak_futures.add(remote_dict)
269285
magic_func.release() # Let go ASAP
270286
if display_expr:
271-
repr_val = client.submit(
287+
repr_future = client.submit(
272288
reprs.repr_afar,
273289
client.submit(get_afar, remote_dict, "_afar_return_value_"),
274290
self._magic_func._repr_methods,
275291
)
276-
weak_futures.add(repr_val)
277-
stdout_future = client.submit(get_afar, remote_dict, "_afar_stdout_")
278-
weak_futures.add(stdout_future)
279-
stderr_future = client.submit(get_afar, remote_dict, "_afar_stderr_")
280-
weak_futures.add(stderr_future)
292+
weak_futures.add(repr_future)
293+
else:
294+
repr_future = None
295+
if display_expr or has_print or _supports_async_output():
296+
stdout_future = client.submit(get_afar, remote_dict, "_afar_stdout_")
297+
weak_futures.add(stdout_future)
298+
stderr_future = client.submit(get_afar, remote_dict, "_afar_stderr_")
299+
weak_futures.add(stderr_future)
281300
if self._gather_data:
282301
futures_to_name = {
283302
client.submit(get_afar, remote_dict, name, **submit_kwargs): name
@@ -294,25 +313,41 @@ def _exit(self, exc_type, exc_value, exc_traceback):
294313
self.data[name] = future
295314
remote_dict.release() # Let go ASAP
296315

297-
# blocks!
298-
stdout_val = stdout_future.result()
299-
if stdout_val:
300-
print(stdout_val, end="")
301-
stdout_future.release()
302-
stderr_val = stderr_future.result()
303-
if stderr_val:
304-
print(stderr_val, end="", file=sys.stderr)
305-
stderr_future.release()
306-
if display_expr:
307-
reprs.display_repr(repr_val.result()) # This blocks!
308-
repr_val.release()
316+
if _supports_async_output():
317+
# Display in `out` cell when data is ready: non-blocking
318+
from IPython.display import display
319+
from ipywidgets import Output
320+
321+
out = Output()
322+
display(out)
323+
# Can we show `distributed.progress` right here?
324+
stdout_future.add_done_callback(
325+
partial(_display_outputs, out, stderr_future, repr_future)
326+
)
327+
elif display_expr or has_print:
328+
# blocks!
329+
stdout_val = stdout_future.result()
330+
stdout_future.release()
331+
if stdout_val:
332+
print(stdout_val, end="")
333+
stderr_val = stderr_future.result()
334+
stderr_future.release()
335+
if stderr_val:
336+
print(stderr_val, end="", file=sys.stderr)
337+
if display_expr:
338+
repr_val = repr_future.result()
339+
repr_future.release()
340+
if repr_val is not None:
341+
reprs.display_repr(repr_val)
309342
elif self._where == "locally":
310343
# Run locally. This is handy for testing and debugging.
311344
results = self._magic_func()
312345
for name in names:
313346
self.data[name] = results[name]
314347
if display_expr:
315-
reprs.IPython.display.display(results.return_value)
348+
from IPython.dislpay import display
349+
350+
display(results.return_value)
316351
elif self._where == "later":
317352
return True
318353
else:
@@ -342,11 +377,28 @@ class Get(Run):
342377
_gather_data = True
343378

344379

380+
def _display_outputs(out, stderr_future, repr_future, stdout_future):
381+
stdout_val = stdout_future.result()
382+
stderr_val = stderr_future.result()
383+
if repr_future is not None:
384+
repr_val = repr_future.result()
385+
else:
386+
repr_val = None
387+
if stdout_val or stderr_val or repr_val is not None:
388+
with out:
389+
if stdout_val:
390+
print(stdout_val, end="")
391+
if stderr_val:
392+
print(stderr_val, end="", file=sys.stderr)
393+
if repr_val is not None:
394+
reprs.display_repr(repr_val)
395+
396+
345397
def abracadabra(runner):
346398
# Create a new function from the code block of the context.
347399
# For now, we require that the source code is available.
348400
source = "def _afar_magic_():\n" + "".join(runner.context_body)
349-
func, display_expr = create_func(source, runner._frame.f_globals, reprs.in_ipython())
401+
func, display_expr = create_func(source, runner._frame.f_globals, reprs.is_kernel())
350402

351403
# If no variable names were given, only get the last assignment
352404
names = runner.names
@@ -422,26 +474,48 @@ def __setstate__(self, state):
422474
self._scoped = innerscope.scoped_function(func, outer_scope)
423475

424476

477+
# Here's the plan: we'll capture all print statements to stdout and stderr
478+
# on the current thread. But, we need to leave the other threads alone!
479+
# So, use `threading.local` and a lock for some ugly capturing.
480+
class LocalPrint(threading.local):
481+
printer = None
482+
483+
def __call__(self, *args, **kwargs):
484+
return self.printer(*args, **kwargs)
485+
486+
425487
class RecordPrint:
488+
n = 0
489+
local_print = LocalPrint()
490+
print_lock = threading.Lock()
491+
426492
def __init__(self):
427493
self.stdout = io.StringIO()
428494
self.stderr = io.StringIO()
429495

430496
def __enter__(self):
431-
self.print = builtins.print
432-
builtins.print = self
497+
with self.print_lock:
498+
if RecordPrint.n == 0:
499+
LocalPrint.printer = builtins.print
500+
builtins.print = self.local_print
501+
RecordPrint.n += 1
502+
self.local_print.printer = self
433503
return self
434504

435505
def __exit__(self, exc_type, exc_value, exc_traceback):
436-
builtins.print = self.print
506+
with self.print_lock:
507+
RecordPrint.n -= 1
508+
if RecordPrint.n == 0:
509+
builtins.print = LocalPrint.printer
510+
self.local_print.printer = LocalPrint.printer
437511
return False
438512

439513
def __call__(self, *args, file=None, **kwargs):
440514
if file is None or file is sys.stdout:
441515
file = self.stdout
442516
elif file is sys.stderr:
443517
file = self.stderr
444-
self.print(*args, **kwargs, file=file)
518+
LocalPrint.printer(*args, **kwargs, file=file)
445519

446520

447521
def run_afar(magic_func, names, futures):

afar/reprs.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
import traceback
44
from types import CodeType, FunctionType
55

6-
try:
7-
import IPython
6+
from distributed.utils import is_kernel # noqa
87

9-
def in_ipython():
10-
return IPython.get_ipython() is not None
118

12-
13-
except ImportError:
14-
15-
def in_ipython():
9+
def in_terminal():
10+
if "IPython" not in sys.modules: # IPython hasn't been imported
1611
return False
12+
from IPython import get_ipython
13+
14+
return type(get_ipython()).__name__ == "TerminalInteractiveShell"
1715

1816

1917
if hasattr(CodeType, "replace"):
@@ -88,7 +86,9 @@ def __getattr__(self, attr):
8886

8987
def get_repr_methods():
9088
"""List of repr methods that IPython/Jupyter tries to use"""
91-
ip = IPython.get_ipython()
89+
from IPython import get_ipython
90+
91+
ip = get_ipython()
9292
if ip is None:
9393
return
9494
attr_recorder = AttrRecorder()
@@ -101,6 +101,8 @@ def repr_afar(val, repr_methods):
101101
102102
We call this on a remote object.
103103
"""
104+
if val is None:
105+
return None
104106
for method_name in repr_methods:
105107
method = getattr(val, method_name, None)
106108
if method is None:
@@ -134,11 +136,15 @@ def display_repr(results):
134136
if is_exception:
135137
print(val, file=sys.stderr)
136138
return
139+
if val is None and method_name is None:
140+
return
137141
if method_name == "_ipython_display_":
138142
val._ipython_display_()
139143
else:
144+
from IPython.display import display
145+
140146
mimic = MimicRepr(val, method_name)
141-
IPython.display.display(mimic)
147+
display(mimic)
142148

143149

144150
class MimicRepr:

0 commit comments

Comments
 (0)