Skip to content

Commit 8bb7fb0

Browse files
committed
ENH: Add crosshairs, modify mode
1 parent af43883 commit 8bb7fb0

File tree

4 files changed

+138
-59
lines changed

4 files changed

+138
-59
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# munges each line before executing it to print out the exit status. It's okay
44
# for it to be on multiple physical lines, so long as you remember: - There
55
# can't be any leading "-"s - All newlines will be removed, so use ";"s
6+
67
language: python
78

89
# Run jobs on container-based infrastructure, can be overridden per job

nibabel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from .imageclasses import class_map, ext_map, all_image_classes
6565
from . import trackvis
6666
from . import mriutils
67+
from . import viewers
6768

6869
# be friendly on systems with ancient numpy -- no tests, but at least
6970
# importable

nibabel/tests/test_viewers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4+
#
5+
# See COPYING file distributed along with the NiBabel package for the
6+
# copyright and license terms.
7+
#
8+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9+
10+
import numpy as np
11+
from collections import namedtuple as nt
12+
13+
from ..optpkg import optional_package
14+
from ..viewers import OrthoSlicer3D
15+
16+
from numpy.testing.decorators import skipif
17+
18+
from nose.tools import assert_raises
19+
20+
plt, has_mpl = optional_package('matplotlib.pyplot')[:2]
21+
needs_mpl = skipif(not has_mpl, 'These tests need matplotlib')
22+
23+
24+
@needs_mpl
25+
def test_viewer():
26+
# Test viewer
27+
a = np.sin(np.linspace(0, np.pi, 20))
28+
b = np.sin(np.linspace(0, np.pi*5, 30))
29+
data = np.outer(a, b)[..., np.newaxis] * a
30+
viewer = OrthoSlicer3D(data)
31+
plt.draw()
32+
33+
# fake some events
34+
viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes
35+
viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes
36+
# tracking on
37+
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
38+
None, 1))
39+
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
40+
plt.gca(), 1))
41+
# tracking off
42+
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
43+
None, None))
44+
viewer.close()
45+
46+
# other cases
47+
fig, axes = plt.subplots(1, 3)
48+
plt.close(fig)
49+
OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes)
50+
assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2, 3],
51+
axes=axes)

nibabel/viewers.py

Lines changed: 85 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
""" Utilities for viewing images
22
3-
Includes version of OrthoSlicer3D code by our own Paul Ivanov
3+
Includes version of OrthoSlicer3D code originally written by our own
4+
Paul Ivanov.
45
"""
56
from __future__ import division, print_function
67

78
import numpy as np
9+
from functools import partial
810

911
from .optpkg import optional_package
1012

@@ -30,26 +32,32 @@
3032
# v +---------+
3133
# <-- x -->
3234

35+
36+
def _set_viewer_slice(idx, im):
37+
"""Helper to set a viewer slice number"""
38+
im.idx = idx
39+
im.set_data(im.get_slice(im.idx))
40+
for fun in im.cross_setters:
41+
fun([idx] * 2)
42+
43+
3344
class OrthoSlicer3D(object):
3445
"""Orthogonal-plane slicer.
3546
3647
OrthoSlicer3d expects 3-dimensional data, and by default it creates a
3748
figure with 3 axes, one for each slice orientation.
3849
39-
There are two modes, "following on" and "following off". In "following on"
40-
mode, moving the mouse in any one axis will select out the corresponding
41-
slices in the other two. The mode is "following off" when the figure is
42-
first created. Clicking the left mouse button toggles mouse following and
43-
triggers a full redraw (to update the ticks, for example). Scrolling up and
50+
Clicking and dragging the mouse in any one axis will select out the
51+
corresponding slices in the other two. Scrolling up and
4452
down moves the slice up and down in the current axis.
4553
4654
Example
4755
-------
48-
import numpy as np
49-
a = np.sin(np.linspace(0,np.pi,20))
50-
b = np.sin(np.linspace(0,np.pi*5,20))
51-
data = np.outer(a,b)[..., np.newaxis]*a
52-
OrthoSlicer3D(data).show()
56+
>>> import numpy as np
57+
>>> a = np.sin(np.linspace(0,np.pi,20))
58+
>>> b = np.sin(np.linspace(0,np.pi*5,20))
59+
>>> data = np.outer(a,b)[..., np.newaxis]*a
60+
>>> OrthoSlicer3D(data).show()
5361
"""
5462
def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
5563
pcnt_range=None):
@@ -70,9 +78,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
7078
scale between image mean and max. If sequence, min and max
7179
percentile over which to scale image.
7280
"""
73-
data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension
74-
aspect_ratio = np.array(aspect_ratio)
75-
if axes is None: # make the axes
81+
data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension
82+
aspect_ratio = np.array(aspect_ratio, float)
83+
if axes is None: # make the axes
7684
# ^ +---------+ ^ +---------+
7785
# | | | | | |
7886
# | | | |
@@ -122,8 +130,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
122130
interpolation='nearest',
123131
cmap=cmap,
124132
origin='lower')
133+
125134
# Start midway through each axis
126135
st_x, st_y, st_z = (data_shape - 1) / 2.
136+
sts = (st_x, st_y, st_z)
127137
n_x, n_y, n_z = data_shape
128138
z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T
129139
y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T
@@ -133,22 +143,51 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
133143
im3 = ax3.imshow(x_get_slice(st_x), **kw)
134144
im1.get_slice, im2.get_slice, im3.get_slice = (
135145
z_get_slice, y_get_slice, x_get_slice)
146+
self._ims = (im1, im2, im3)
147+
136148
# idx is the current slice number for each panel
137149
im1.idx, im2.idx, im3.idx = st_z, st_y, st_x
150+
138151
# set the maximum dimensions for indexing
139152
im1.size, im2.size, im3.size = n_z, n_y, n_x
153+
154+
# set up axis crosshairs
155+
colors = ['r', 'g', 'b']
156+
for ax, im, idx_1, idx_2 in zip(axes, self._ims, [0, 0, 1], [1, 2, 2]):
157+
im.x_line = ax.plot([sts[idx_1]] * 2,
158+
[-0.5, data.shape[idx_2] - 0.5],
159+
color=colors[idx_1], linestyle='-',
160+
alpha=0.25)[0]
161+
im.y_line = ax.plot([-0.5, data.shape[idx_1] - 0.5],
162+
[sts[idx_2]] * 2,
163+
color=colors[idx_2], linestyle='-',
164+
alpha=0.25)[0]
165+
ax.axis('tight')
166+
ax.patch.set_visible(False)
167+
ax.set_frame_on(False)
168+
ax.axes.get_yaxis().set_visible(False)
169+
ax.axes.get_xaxis().set_visible(False)
170+
171+
# monkey-patch some functions
172+
im1.set_viewer_slice = partial(_set_viewer_slice, im=im1)
173+
im2.set_viewer_slice = partial(_set_viewer_slice, im=im2)
174+
im3.set_viewer_slice = partial(_set_viewer_slice, im=im3)
175+
140176
# setup pairwise connections between the slice dimensions
141-
im1.imx = im3 # x move in panel 1 (usually axial)
142-
im1.imy = im2 # y move in panel 1
143-
im2.imx = im3 # x move in panel 2 (usually coronal)
144-
im2.imy = im1
145-
im3.imx = im2 # x move in panel 3 (usually sagittal)
146-
im3.imy = im1
147-
148-
self.follow = False
177+
im1.x_im = im3 # x move in panel 1 (usually axial)
178+
im1.y_im = im2 # y move in panel 1
179+
im2.x_im = im3 # x move in panel 2 (usually coronal)
180+
im2.y_im = im1 # y move in panel 2
181+
im3.x_im = im2 # x move in panel 3 (usually sagittal)
182+
im3.y_im = im1 # y move in panel 3
183+
184+
# when an index changes, which crosshairs need to be updated
185+
im1.cross_setters = [im2.y_line.set_ydata, im3.y_line.set_ydata]
186+
im2.cross_setters = [im1.y_line.set_ydata, im3.x_line.set_xdata]
187+
im3.cross_setters = [im1.x_line.set_xdata, im2.x_line.set_xdata]
188+
149189
self.figs = set([ax.figure for ax in axes])
150190
for fig in self.figs:
151-
fig.canvas.mpl_connect('button_press_event', self.on_click)
152191
fig.canvas.mpl_connect('scroll_event', self.on_scroll)
153192
fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove)
154193

@@ -157,59 +196,46 @@ def show(self):
157196
"""
158197
plt.show()
159198

199+
def close(self):
200+
"""Close the viewer figures
201+
"""
202+
for f in self.figs:
203+
plt.close(f)
204+
160205
def _axis_artist(self, event):
161-
""" Return artist if within axes, and is an image, else None
206+
"""Return artist if within axes, and is an image, else None
162207
"""
163208
if not getattr(event, 'inaxes'):
164209
return None
165210
artist = event.inaxes.images[0]
166211
return artist if isinstance(artist, mpl_img.AxesImage) else None
167212

168-
def on_click(self, event):
169-
if event.button == 1:
170-
self.follow = not self.follow
171-
plt.draw()
172-
173213
def on_scroll(self, event):
174214
assert event.button in ('up', 'down')
175215
im = self._axis_artist(event)
176216
if im is None:
177217
return
178-
im.idx += 1 if event.button == 'up' else -1
179-
im.idx %= im.size
180-
im.set_data(im.get_slice(im.idx))
181-
ax = im.axes
182-
ax.draw_artist(im)
183-
ax.figure.canvas.blit(ax.bbox)
218+
idx = (im.idx + (1 if event.button == 'up' else -1))
219+
idx = max(min(idx, im.size - 1), 0)
220+
im.set_viewer_slice(idx)
221+
self._draw_ims()
184222

185223
def on_mousemove(self, event):
186-
if not self.follow:
224+
if event.button != 1: # only enabled while dragging
187225
return
188226
im = self._axis_artist(event)
189227
if im is None:
190228
return
191-
ax = im.axes
192-
imx, imy = im.imx, im.imy
229+
x_im, y_im = im.x_im, im.y_im
193230
x, y = np.round((event.xdata, event.ydata)).astype(int)
194-
imx.set_data(imx.get_slice(x))
195-
imy.set_data(imy.get_slice(y))
196-
imx.idx = x
197-
imy.idx = y
198-
for i in imx, imy:
199-
ax = i.axes
200-
ax.draw_artist(i)
231+
for i, idx in zip((x_im, y_im), (x, y)):
232+
i.set_viewer_slice(idx)
233+
self._draw_ims()
234+
235+
def _draw_ims(self):
236+
for im in self._ims:
237+
ax = im.axes
238+
ax.draw_artist(im)
239+
ax.draw_artist(im.x_line)
240+
ax.draw_artist(im.y_line)
201241
ax.figure.canvas.blit(ax.bbox)
202-
203-
204-
if __name__ == '__main__':
205-
a = np.sin(np.linspace(0,np.pi,20))
206-
b = np.sin(np.linspace(0,np.pi*5,20))
207-
data = np.outer(a,b)[..., np.newaxis]*a
208-
# all slices
209-
OrthoSlicer3D(data).show()
210-
211-
# broken out into three separate figures
212-
f, ax1 = plt.subplots()
213-
f, ax2 = plt.subplots()
214-
f, ax3 = plt.subplots()
215-
OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show()

0 commit comments

Comments
 (0)