|
| 1 | +"""Visualization tooling.""" |
| 2 | + |
| 3 | + |
| 4 | +def plot_registration( |
| 5 | + anat_nii, |
| 6 | + div_id, |
| 7 | + plot_params=None, |
| 8 | + order=("z", "x", "y"), |
| 9 | + cuts=None, |
| 10 | + estimate_brightness=False, |
| 11 | + label=None, |
| 12 | + contour=None, |
| 13 | + compress="auto", |
| 14 | +): |
| 15 | + """ |
| 16 | + Plot the foreground and background views. |
| 17 | + Default order is: axial, coronal, sagittal |
| 18 | + """ |
| 19 | + from uuid import uuid4 |
| 20 | + |
| 21 | + from lxml import etree |
| 22 | + from nilearn.plotting import plot_anat |
| 23 | + from svgutils.transform import SVGFigure |
| 24 | + from niworkflows.viz.utils import robust_set_limits, extract_svg, SVGNS |
| 25 | + |
| 26 | + plot_params = plot_params or {} |
| 27 | + |
| 28 | + # Use default MNI cuts if none defined |
| 29 | + if cuts is None: |
| 30 | + raise NotImplementedError # TODO |
| 31 | + |
| 32 | + out_files = [] |
| 33 | + if estimate_brightness: |
| 34 | + plot_params = robust_set_limits(anat_nii.get_data().reshape(-1), plot_params) |
| 35 | + |
| 36 | + # Plot each cut axis |
| 37 | + for i, mode in enumerate(list(order)): |
| 38 | + plot_params["display_mode"] = mode |
| 39 | + plot_params["cut_coords"] = cuts[mode] |
| 40 | + if i == 0: |
| 41 | + plot_params["title"] = label |
| 42 | + else: |
| 43 | + plot_params["title"] = None |
| 44 | + |
| 45 | + # Generate nilearn figure |
| 46 | + display = plot_anat(anat_nii, **plot_params) |
| 47 | + if contour is not None: |
| 48 | + display.add_contours(contour, colors="g", levels=[0.5], linewidths=0.5) |
| 49 | + |
| 50 | + svg = extract_svg(display, compress=compress) |
| 51 | + display.close() |
| 52 | + |
| 53 | + # Find and replace the figure_1 id. |
| 54 | + xml_data = etree.fromstring(svg) |
| 55 | + find_text = etree.ETXPath("//{%s}g[@id='figure_1']" % SVGNS) |
| 56 | + find_text(xml_data)[0].set("id", "%s-%s-%s" % (div_id, mode, uuid4())) |
| 57 | + |
| 58 | + svg_fig = SVGFigure() |
| 59 | + svg_fig.root = xml_data |
| 60 | + out_files.append(svg_fig) |
| 61 | + |
| 62 | + return out_files |
| 63 | + |
| 64 | + |
| 65 | +def coolwarm_transparent(): |
| 66 | + """Modify the coolwarm color scale to have full transparency around the middle.""" |
| 67 | + import numpy as np |
| 68 | + import matplotlib.pylab as pl |
| 69 | + from matplotlib.colors import ListedColormap |
| 70 | + |
| 71 | + # Choose colormap |
| 72 | + cmap = pl.cm.coolwarm |
| 73 | + |
| 74 | + # Get the colormap colors |
| 75 | + my_cmap = cmap(np.arange(cmap.N)) |
| 76 | + |
| 77 | + # Set alpha |
| 78 | + alpha = np.ones(cmap.N) |
| 79 | + alpha[128:160] = np.linspace(0, 1, len(alpha[128:160])) |
| 80 | + alpha[96:128] = np.linspace(1, 0, len(alpha[96:128])) |
| 81 | + my_cmap[:, -1] = alpha |
| 82 | + return ListedColormap(my_cmap) |
0 commit comments