Skip to content

Commit 843897b

Browse files
committed
ENH: Show registration reportlets inline within Jupyter notebooks
Add a little wrapper around our registration reportlet to allow them to render (and flicker) within a notebook. Created for nipy/nitransforms#93.
1 parent e43c6c1 commit 843897b

File tree

2 files changed

+67
-10
lines changed

2 files changed

+67
-10
lines changed

niworkflows/viz/notebook.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Visualization component for Jupyter Notebooks."""
2+
from pathlib import Path
3+
import numpy as np
4+
import nibabel as nb
5+
from .utils import compose_view, plot_registration, cuts_from_bbox
6+
from IPython.display import SVG, display as _disp
7+
8+
9+
def display(
10+
fixed_image,
11+
moving_image,
12+
contour=None,
13+
cuts=None,
14+
fixed_label="F",
15+
moving_label="M",
16+
):
17+
"""Plot the flickering panels to show a registration process."""
18+
if isinstance(fixed_image, (str, Path)):
19+
fixed_image = nb.load(str(fixed_image))
20+
if isinstance(moving_image, (str, Path)):
21+
moving_image = nb.load(str(moving_image))
22+
23+
if cuts is None:
24+
n_cuts = 7
25+
if contour is not None:
26+
if isinstance(contour, (str, Path)):
27+
contour = nb.load(str(contour))
28+
cuts = cuts_from_bbox(contour, cuts=n_cuts)
29+
else:
30+
hdr = fixed_image.header.copy()
31+
hdr.set_data_dtype('uint8')
32+
mask_nii = nb.Nifti1Image(
33+
np.ones(fixed_image.shape, dtype='uint8'),
34+
fixed_image.affine, hdr)
35+
cuts = cuts_from_bbox(mask_nii, cuts=n_cuts)
36+
37+
# Call composer
38+
_disp(SVG(compose_view(
39+
plot_registration(fixed_image, 'fixed-image',
40+
estimate_brightness=True,
41+
cuts=cuts,
42+
label=fixed_label,
43+
contour=contour,
44+
compress=False),
45+
plot_registration(moving_image, 'moving-image',
46+
estimate_brightness=True,
47+
cuts=cuts,
48+
label=moving_label,
49+
contour=contour,
50+
compress=False),
51+
)))

niworkflows/viz/utils.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Helper tools for visualization purposes"""
44
from pathlib import Path
55
from shutil import which
6+
from tempfile import TemporaryDirectory
67
import subprocess
78
import base64
89
import re
@@ -327,10 +328,15 @@ def plot_registration(anat_nii, div_id, plot_params=None,
327328

328329

329330
def compose_view(bg_svgs, fg_svgs, ref=0, out_file='report.svg'):
330-
"""
331-
Composes the input svgs into one standalone svg and inserts
332-
the CSS code for the flickering animation
333-
"""
331+
"""Compose the input svgs into one standalone svg with CSS flickering animation."""
332+
out_file = Path(out_file).absolute()
333+
out_file.write_text("\n".join(
334+
_compose_view(bg_svgs, fg_svgs, ref=ref)
335+
))
336+
return str(out_file)
337+
338+
339+
def _compose_view(bg_svgs, fg_svgs, ref=0):
334340

335341
if fg_svgs is None:
336342
fg_svgs = []
@@ -379,11 +385,12 @@ def compose_view(bg_svgs, fg_svgs, ref=0, out_file='report.svg'):
379385
fig.root.attrib.pop("width")
380386
fig.root.attrib.pop("height")
381387
fig.root.set("preserveAspectRatio", "xMidYMid meet")
382-
out_file = Path(out_file).absolute()
383-
fig.save(str(out_file))
384388

385-
# Post processing
386-
svg = out_file.read_text().splitlines()
389+
with TemporaryDirectory() as tmpdirname:
390+
out_file = Path(tmpdirname) / "tmp.svg"
391+
fig.save(str(out_file))
392+
# Post processing
393+
svg = out_file.read_text().splitlines()
387394

388395
# Remove <?xml... line
389396
if svg[0].startswith("<?xml"):
@@ -398,8 +405,7 @@ def compose_view(bg_svgs, fg_svgs, ref=0, out_file='report.svg'):
398405
.foreground-svg:hover { animation-play-state: running;}
399406
</style>""" % tuple([uuid4()] * 2))
400407

401-
out_file.write_text("\n".join(svg))
402-
return str(out_file)
408+
return svg
403409

404410

405411
def transform_to_2d(data, max_axis):

0 commit comments

Comments
 (0)