Skip to content

Commit 776dc61

Browse files
committed
add reportlet features from sdcflows
1 parent 79a1e5c commit 776dc61

File tree

3 files changed

+157
-0
lines changed

3 files changed

+157
-0
lines changed

dmriprep/interfaces/reportlets.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Interfaces to generate speciality reportlets."""
2+
from nilearn.image import threshold_img, load_img
3+
from niworkflows import NIWORKFLOWS_LOG
4+
from niworkflows.viz.utils import cuts_from_bbox, compose_view
5+
from nipype.interfaces.base import File, isdefined
6+
from nipype.interfaces.mixins import reporting
7+
8+
from ..viz.utils import plot_registration, coolwarm_transparent
9+
10+
11+
class FieldmapReportletInputSpec(reporting.ReportCapableInputSpec):
12+
reference = File(exists=True, mandatory=True, desc="input reference")
13+
fieldmap = File(exists=True, mandatory=True, desc="input fieldmap")
14+
mask = File(exists=True, desc="brain mask")
15+
out_report = File(
16+
"report.svg", usedefault=True, desc="filename for the visual report"
17+
)
18+
19+
20+
class FieldmapReportlet(reporting.ReportCapableInterface):
21+
"""An abstract mixin to registration nipype interfaces."""
22+
23+
_n_cuts = 7
24+
input_spec = FieldmapReportletInputSpec
25+
output_spec = reporting.ReportCapableOutputSpec
26+
27+
def __init__(self, **kwargs):
28+
"""Instantiate FieldmapReportlet."""
29+
self._n_cuts = kwargs.pop("n_cuts", self._n_cuts)
30+
super(FieldmapReportlet, self).__init__(generate_report=True, **kwargs)
31+
32+
def _run_interface(self, runtime):
33+
return runtime
34+
35+
def _generate_report(self):
36+
"""Generate a reportlet."""
37+
NIWORKFLOWS_LOG.info("Generating visual report")
38+
39+
refnii = load_img(self.inputs.reference)
40+
fmapnii = load_img(self.inputs.fieldmap)
41+
contour_nii = (
42+
load_img(self.inputs.mask) if isdefined(self.inputs.mask) else None
43+
)
44+
mask_nii = threshold_img(refnii, 1e-3)
45+
cuts = cuts_from_bbox(contour_nii or mask_nii, cuts=self._n_cuts)
46+
fmapdata = fmapnii.get_fdata()
47+
vmax = max(fmapdata.max(), abs(fmapdata.min()))
48+
49+
# Call composer
50+
compose_view(
51+
plot_registration(
52+
refnii,
53+
"fixed-image",
54+
estimate_brightness=True,
55+
cuts=cuts,
56+
label="reference",
57+
contour=contour_nii,
58+
compress=False,
59+
),
60+
plot_registration(
61+
fmapnii,
62+
"moving-image",
63+
estimate_brightness=True,
64+
cuts=cuts,
65+
label="fieldmap (Hz)",
66+
contour=contour_nii,
67+
compress=False,
68+
plot_params={
69+
"cmap": coolwarm_transparent(),
70+
"vmax": vmax,
71+
"vmin": -vmax,
72+
},
73+
),
74+
out_file=self._out_report,
75+
)

dmriprep/viz/__init__.py

Whitespace-only changes.

dmriprep/viz/utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Visualization tooling."""
2+
3+
4+
def plot_registration(
5+
anat_nii,
6+
div_id,
7+
plot_params=None,
8+
order=("z", "x", "y"),
9+
cuts=None,
10+
estimate_brightness=False,
11+
label=None,
12+
contour=None,
13+
compress="auto",
14+
):
15+
"""
16+
Plot the foreground and background views.
17+
Default order is: axial, coronal, sagittal
18+
"""
19+
from uuid import uuid4
20+
21+
from lxml import etree
22+
from nilearn.plotting import plot_anat
23+
from svgutils.transform import SVGFigure
24+
from niworkflows.viz.utils import robust_set_limits, extract_svg, SVGNS
25+
26+
plot_params = plot_params or {}
27+
28+
# Use default MNI cuts if none defined
29+
if cuts is None:
30+
raise NotImplementedError # TODO
31+
32+
out_files = []
33+
if estimate_brightness:
34+
plot_params = robust_set_limits(anat_nii.get_data().reshape(-1), plot_params)
35+
36+
# Plot each cut axis
37+
for i, mode in enumerate(list(order)):
38+
plot_params["display_mode"] = mode
39+
plot_params["cut_coords"] = cuts[mode]
40+
if i == 0:
41+
plot_params["title"] = label
42+
else:
43+
plot_params["title"] = None
44+
45+
# Generate nilearn figure
46+
display = plot_anat(anat_nii, **plot_params)
47+
if contour is not None:
48+
display.add_contours(contour, colors="g", levels=[0.5], linewidths=0.5)
49+
50+
svg = extract_svg(display, compress=compress)
51+
display.close()
52+
53+
# Find and replace the figure_1 id.
54+
xml_data = etree.fromstring(svg)
55+
find_text = etree.ETXPath("//{%s}g[@id='figure_1']" % SVGNS)
56+
find_text(xml_data)[0].set("id", "%s-%s-%s" % (div_id, mode, uuid4()))
57+
58+
svg_fig = SVGFigure()
59+
svg_fig.root = xml_data
60+
out_files.append(svg_fig)
61+
62+
return out_files
63+
64+
65+
def coolwarm_transparent():
66+
"""Modify the coolwarm color scale to have full transparency around the middle."""
67+
import numpy as np
68+
import matplotlib.pylab as pl
69+
from matplotlib.colors import ListedColormap
70+
71+
# Choose colormap
72+
cmap = pl.cm.coolwarm
73+
74+
# Get the colormap colors
75+
my_cmap = cmap(np.arange(cmap.N))
76+
77+
# Set alpha
78+
alpha = np.ones(cmap.N)
79+
alpha[128:160] = np.linspace(0, 1, len(alpha[128:160]))
80+
alpha[96:128] = np.linspace(1, 0, len(alpha[96:128]))
81+
my_cmap[:, -1] = alpha
82+
return ListedColormap(my_cmap)

0 commit comments

Comments
 (0)