Skip to content

Commit af43883

Browse files
matthew-brettlarsoner
authored andcommitted
NF: add version of Paul Ivanov's slice viewer
Thanks Paul...
1 parent b66d63b commit af43883

File tree

1 file changed

+215
-0
lines changed

1 file changed

+215
-0
lines changed

nibabel/viewers.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
""" Utilities for viewing images
2+
3+
Includes version of OrthoSlicer3D code by our own Paul Ivanov
4+
"""
5+
from __future__ import division, print_function
6+
7+
import numpy as np
8+
9+
from .optpkg import optional_package
10+
11+
plt, _, _ = optional_package('matplotlib.pyplot')
12+
mpl_img, _, _ = optional_package('matplotlib.image')
13+
14+
# Assumes the following layout
15+
#
16+
# ^ +---------+ ^ +---------+
17+
# | | | | | |
18+
# | | | |
19+
# z | 2 | z | 3 |
20+
# | | | |
21+
# | | | | | |
22+
# v +---------+ v +---------+
23+
# <-- x --> <-- y -->
24+
# ^ +---------+
25+
# | | |
26+
# | |
27+
# y | 1 |
28+
# | |
29+
# | | |
30+
# v +---------+
31+
# <-- x -->
32+
33+
class OrthoSlicer3D(object):
34+
"""Orthogonal-plane slicer.
35+
36+
OrthoSlicer3d expects 3-dimensional data, and by default it creates a
37+
figure with 3 axes, one for each slice orientation.
38+
39+
There are two modes, "following on" and "following off". In "following on"
40+
mode, moving the mouse in any one axis will select out the corresponding
41+
slices in the other two. The mode is "following off" when the figure is
42+
first created. Clicking the left mouse button toggles mouse following and
43+
triggers a full redraw (to update the ticks, for example). Scrolling up and
44+
down moves the slice up and down in the current axis.
45+
46+
Example
47+
-------
48+
import numpy as np
49+
a = np.sin(np.linspace(0,np.pi,20))
50+
b = np.sin(np.linspace(0,np.pi*5,20))
51+
data = np.outer(a,b)[..., np.newaxis]*a
52+
OrthoSlicer3D(data).show()
53+
"""
54+
def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
55+
pcnt_range=None):
56+
"""
57+
Parameters
58+
----------
59+
data : 3 dimensional ndarray
60+
The data that will be displayed by the slicer
61+
axes : None or length 3 sequence of mpl.Axes, optional
62+
3 axes instances for the X, Y, and Z slices, or None (default)
63+
aspect_ratio : float or length 3 sequence, optional
64+
stretch factors for X, Y, Z directions
65+
cmap : colormap identifier, optional
66+
String or cmap instance specifying colormap. Will be passed as
67+
``cmap`` argument to ``plt.imshow``.
68+
pcnt_range : length 2 sequence, optional
69+
Percentile range over which to scale image for display. If None,
70+
scale between image mean and max. If sequence, min and max
71+
percentile over which to scale image.
72+
"""
73+
data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension
74+
aspect_ratio = np.array(aspect_ratio)
75+
if axes is None: # make the axes
76+
# ^ +---------+ ^ +---------+
77+
# | | | | | |
78+
# | | | |
79+
# z | 2 | z | 3 |
80+
# | | | |
81+
# | | | | | |
82+
# v +---------+ v +---------+
83+
# <-- x --> <-- y -->
84+
# ^ +---------+
85+
# | | |
86+
# | |
87+
# y | 1 |
88+
# | |
89+
# | | |
90+
# v +---------+
91+
# <-- x -->
92+
fig = plt.figure()
93+
x, y, z = data_shape * aspect_ratio
94+
maxw = float(x + y)
95+
maxh = float(y + z)
96+
yh = y / maxh
97+
xw = x / maxw
98+
yw = y / maxw
99+
zh = z / maxh
100+
# z slice (if usual transverse acquisition => axial slice)
101+
ax1 = fig.add_axes((0., 0., xw, yh))
102+
# y slice (usually coronal)
103+
ax2 = fig.add_axes((0, yh, xw, zh))
104+
# x slice (usually sagittal)
105+
ax3 = fig.add_axes((xw, yh, yw, zh))
106+
axes = (ax1, ax2, ax3)
107+
else:
108+
if not np.all(aspect_ratio == 1):
109+
raise ValueError('Aspect ratio must be 1 for external axes')
110+
ax1, ax2, ax3 = axes
111+
112+
self.data = data
113+
114+
if pcnt_range is None:
115+
vmin, vmax = data.min(), data.max()
116+
else:
117+
vmin, vmax = np.percentile(data, pcnt_range)
118+
119+
kw = dict(vmin=vmin,
120+
vmax=vmax,
121+
aspect='auto',
122+
interpolation='nearest',
123+
cmap=cmap,
124+
origin='lower')
125+
# Start midway through each axis
126+
st_x, st_y, st_z = (data_shape - 1) / 2.
127+
n_x, n_y, n_z = data_shape
128+
z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T
129+
y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T
130+
x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T
131+
im1 = ax1.imshow(z_get_slice(st_z), **kw)
132+
im2 = ax2.imshow(y_get_slice(st_y), **kw)
133+
im3 = ax3.imshow(x_get_slice(st_x), **kw)
134+
im1.get_slice, im2.get_slice, im3.get_slice = (
135+
z_get_slice, y_get_slice, x_get_slice)
136+
# idx is the current slice number for each panel
137+
im1.idx, im2.idx, im3.idx = st_z, st_y, st_x
138+
# set the maximum dimensions for indexing
139+
im1.size, im2.size, im3.size = n_z, n_y, n_x
140+
# setup pairwise connections between the slice dimensions
141+
im1.imx = im3 # x move in panel 1 (usually axial)
142+
im1.imy = im2 # y move in panel 1
143+
im2.imx = im3 # x move in panel 2 (usually coronal)
144+
im2.imy = im1
145+
im3.imx = im2 # x move in panel 3 (usually sagittal)
146+
im3.imy = im1
147+
148+
self.follow = False
149+
self.figs = set([ax.figure for ax in axes])
150+
for fig in self.figs:
151+
fig.canvas.mpl_connect('button_press_event', self.on_click)
152+
fig.canvas.mpl_connect('scroll_event', self.on_scroll)
153+
fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove)
154+
155+
def show(self):
156+
""" Show the slicer; convenience for ``plt.show()``
157+
"""
158+
plt.show()
159+
160+
def _axis_artist(self, event):
161+
""" Return artist if within axes, and is an image, else None
162+
"""
163+
if not getattr(event, 'inaxes'):
164+
return None
165+
artist = event.inaxes.images[0]
166+
return artist if isinstance(artist, mpl_img.AxesImage) else None
167+
168+
def on_click(self, event):
169+
if event.button == 1:
170+
self.follow = not self.follow
171+
plt.draw()
172+
173+
def on_scroll(self, event):
174+
assert event.button in ('up', 'down')
175+
im = self._axis_artist(event)
176+
if im is None:
177+
return
178+
im.idx += 1 if event.button == 'up' else -1
179+
im.idx %= im.size
180+
im.set_data(im.get_slice(im.idx))
181+
ax = im.axes
182+
ax.draw_artist(im)
183+
ax.figure.canvas.blit(ax.bbox)
184+
185+
def on_mousemove(self, event):
186+
if not self.follow:
187+
return
188+
im = self._axis_artist(event)
189+
if im is None:
190+
return
191+
ax = im.axes
192+
imx, imy = im.imx, im.imy
193+
x, y = np.round((event.xdata, event.ydata)).astype(int)
194+
imx.set_data(imx.get_slice(x))
195+
imy.set_data(imy.get_slice(y))
196+
imx.idx = x
197+
imy.idx = y
198+
for i in imx, imy:
199+
ax = i.axes
200+
ax.draw_artist(i)
201+
ax.figure.canvas.blit(ax.bbox)
202+
203+
204+
if __name__ == '__main__':
205+
a = np.sin(np.linspace(0,np.pi,20))
206+
b = np.sin(np.linspace(0,np.pi*5,20))
207+
data = np.outer(a,b)[..., np.newaxis]*a
208+
# all slices
209+
OrthoSlicer3D(data).show()
210+
211+
# broken out into three separate figures
212+
f, ax1 = plt.subplots()
213+
f, ax2 = plt.subplots()
214+
f, ax3 = plt.subplots()
215+
OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show()

0 commit comments

Comments
 (0)