Skip to content

Commit bf54667

Browse files
committed
FIX: Remove monkey patching
1 parent 11ff239 commit bf54667

File tree

2 files changed

+59
-56
lines changed

2 files changed

+59
-56
lines changed

nibabel/tests/test_viewers.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,28 @@ def test_viewer():
2828
b = np.sin(np.linspace(0, np.pi*5, 30))
2929
data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis]
3030
data = data * np.array([1., 2.]) # give it a # of volumes > 1
31-
viewer = OrthoSlicer3D(data)
31+
v = OrthoSlicer3D(data)
3232
plt.draw()
3333

3434
# fake some events, inside and outside axes
35-
viewer._on_scroll(nt('event', 'button inaxes key')('up', None, None))
36-
for ax in (viewer._axes['x'], viewer._axes['v']):
37-
viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, None))
38-
viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift'))
35+
v._on_scroll(nt('event', 'button inaxes key')('up', None, None))
36+
for ax in (v._axes['x'], v._axes['v']):
37+
v._on_scroll(nt('event', 'button inaxes key')('up', ax, None))
38+
v._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift'))
3939
# "click" outside axes, then once in each axis, then move without click
40-
viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
41-
None, 1))
42-
for ax in viewer._axes.values():
43-
viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
44-
ax, 1))
45-
viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
46-
None, None))
47-
viewer.set_indices(0, 1, 2)
48-
viewer.set_indices(v=10)
49-
viewer.close()
40+
v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1))
41+
for ax in v._axes.values():
42+
v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1))
43+
v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None))
44+
v.set_indices(0, 1, 2)
45+
v.set_indices(v=10)
46+
v.close()
5047

5148
# non-multi-volume
52-
viewer = OrthoSlicer3D(data[:, :, :, 0])
53-
assert_raises(ValueError, viewer.set_indices, v=10) # not multi-volume
54-
viewer._on_scroll(nt('event', 'button inaxes key')('up', viewer._axes['x'],
55-
'shift'))
56-
viewer._on_keypress(nt('event', 'key')('escape'))
49+
v = OrthoSlicer3D(data[:, :, :, 0])
50+
assert_raises(ValueError, v.set_indices, v=10) # not multi-volume
51+
v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift'))
52+
v._on_keypress(nt('event', 'key')('escape'))
5753

5854
# other cases
5955
fig, axes = plt.subplots(1, 4)
@@ -63,3 +59,4 @@ def test_viewer():
6359
OrthoSlicer3D(data, axes=axes[:3])
6460
assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2])
6561
assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0])
62+
assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3))

nibabel/viewers.py

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class OrthoSlicer3D(object):
3333
>>> OrthoSlicer3D(data).show() # doctest: +SKIP
3434
"""
3535
# Skip doctest above b/c not all systems have mpl installed
36-
def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
37-
pcnt_range=(1., 99.), figsize=(8, 8)):
36+
def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None,
37+
cmap='gray', pcnt_range=(1., 99.), figsize=(8, 8)):
3838
"""
3939
Parameters
4040
----------
@@ -46,13 +46,14 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
4646
or None (default).
4747
aspect_ratio : array-like, optional
4848
Stretch factors for X, Y, Z directions.
49+
affine : array-like | None
50+
Affine transform for the data. This is used to determine
51+
how the data should be sliced for plotting into the X, Y,
52+
and Z view axes. If None, identity is assumed.
4953
cmap : str | instance of cmap, optional
50-
String or cmap instance specifying colormap. Will be passed as
51-
``cmap`` argument to ``plt.imshow``.
54+
String or cmap instance specifying colormap.
5255
pcnt_range : array-like, optional
53-
Percentile range over which to scale image for display. If None,
54-
scale between image mean and max. If sequence, min and max
55-
percentile over which to scale image.
56+
Percentile range over which to scale image for display.
5657
figsize : tuple
5758
Figure size (in inches) to use if axes are None.
5859
"""
@@ -63,6 +64,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
6364
data = np.asanyarray(data)
6465
if data.ndim < 3:
6566
raise ValueError('data must have at least 3 dimensions')
67+
affine = np.array(affine, float) if affine is not None else np.eye(4)
68+
if affine.ndim != 2 or affine.shape != (4, 4):
69+
raise ValueError('affine must be a 4x4 matrix')
70+
self._affine = affine
6671
self._volume_dims = data.shape[3:]
6772
self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data
6873
self._data = data
@@ -116,17 +121,16 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
116121
labels = dict(z='ILSR', y='ALPR', x='AIPS')
117122

118123
# set up axis crosshairs
124+
self._crosshairs = dict()
119125
for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'):
120-
ax = self._axes[type_]
121-
im = self._ims[type_]
122-
label = labels[type_]
123-
# add slice lines
124-
im.vert_line = ax.plot([self._idx[i_1]] * 2,
125-
[-0.5, self._sizes[i_2] - 0.5],
126-
color=colors[i_1], linestyle='-')[0]
127-
im.horiz_line = ax.plot([-0.5, self._sizes[i_1] - 0.5],
128-
[self._idx[i_2]] * 2,
129-
color=colors[i_2], linestyle='-')[0]
126+
ax, label = self._axes[type_], labels[type_]
127+
vert = ax.plot([self._idx[i_1]] * 2,
128+
[-0.5, self._sizes[i_2] - 0.5],
129+
color=colors[i_1], linestyle='-')[0]
130+
horiz = ax.plot([-0.5, self._sizes[i_1] - 0.5],
131+
[self._idx[i_2]] * 2,
132+
color=colors[i_2], linestyle='-')[0]
133+
self._crosshairs[type_] = dict(vert=vert, horiz=horiz)
130134
# add text labels (top, right, bottom, left)
131135
lims = [0, self._sizes[i_1], 0, self._sizes[i_2]]
132136
bump = 0.01
@@ -136,10 +140,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
136140
[lims[0] - bump * lims[1], lims[3] / 2.]]
137141
anchors = [['center', 'bottom'], ['left', 'center'],
138142
['center', 'top'], ['right', 'center']]
139-
im.texts = [ax.text(pos[0], pos[1], lab,
140-
horizontalalignment=anchor[0],
141-
verticalalignment=anchor[1])
142-
for pos, anchor, lab in zip(poss, anchors, label)]
143+
for pos, anchor, lab in zip(poss, anchors, label):
144+
ax.text(pos[0], pos[1], lab,
145+
horizontalalignment=anchor[0],
146+
verticalalignment=anchor[1])
143147
ax.axis(lims)
144148
ax.set_aspect(aspect_ratio[type_])
145149
ax.patch.set_visible(False)
@@ -172,18 +176,18 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
172176

173177
# when an index changes, which crosshairs need to be updated
174178
self._cross_setters = dict(
175-
x=[self._ims['z'].vert_line.set_xdata,
176-
self._ims['y'].vert_line.set_xdata],
177-
y=[self._ims['z'].horiz_line.set_ydata,
178-
self._ims['x'].vert_line.set_xdata],
179-
z=[self._ims['y'].horiz_line.set_ydata,
180-
self._ims['x'].horiz_line.set_ydata])
179+
x=[self._crosshairs['z']['vert'].set_xdata,
180+
self._crosshairs['y']['vert'].set_xdata],
181+
y=[self._crosshairs['z']['horiz'].set_ydata,
182+
self._crosshairs['x']['vert'].set_xdata],
183+
z=[self._crosshairs['y']['horiz'].set_ydata,
184+
self._crosshairs['x']['horiz'].set_ydata])
181185

182186
self._figs = set([a.figure for a in self._axes.values()])
183187
for fig in self._figs:
184188
fig.canvas.mpl_connect('scroll_event', self._on_scroll)
185-
fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove)
186-
fig.canvas.mpl_connect('button_press_event', self._on_mousemove)
189+
fig.canvas.mpl_connect('motion_notify_event', self._on_mouse)
190+
fig.canvas.mpl_connect('button_press_event', self._on_mouse)
187191
fig.canvas.mpl_connect('key_press_event', self._on_keypress)
188192

189193
def show(self):
@@ -279,6 +283,7 @@ def _in_axis(self, event):
279283
return key
280284

281285
def _on_scroll(self, event):
286+
"""Handle mpl scroll wheel event"""
282287
assert event.button in ('up', 'down')
283288
key = self._in_axis(event)
284289
if key is None:
@@ -296,7 +301,8 @@ def _on_scroll(self, event):
296301
self._update_voxel_levels()
297302
self._draw()
298303

299-
def _on_mousemove(self, event):
304+
def _on_mouse(self, event):
305+
"""Handle mpl mouse move and button press events"""
300306
if event.button != 1: # only enabled while dragging
301307
return
302308
key = self._in_axis(event)
@@ -312,18 +318,18 @@ def _on_mousemove(self, event):
312318
self._draw()
313319

314320
def _on_keypress(self, event):
321+
"""Handle mpl keypress events"""
315322
if event.key is not None and 'escape' in event.key:
316323
self.close()
317324

318325
def _draw(self):
319-
for im in self._ims.values():
320-
ax = im.axes
326+
"""Update all four (or three) plots"""
327+
for key in 'xyz':
328+
ax, im = self._axes[key], self._ims[key]
321329
ax.draw_artist(im)
322-
ax.draw_artist(im.vert_line)
323-
ax.draw_artist(im.horiz_line)
330+
for line in self._crosshairs[key].values():
331+
ax.draw_artist(line)
324332
ax.figure.canvas.blit(ax.bbox)
325-
for t in im.texts:
326-
ax.draw_artist(t)
327333
if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3
328334
ax = self._axes['v']
329335
ax.draw_artist(ax.patch) # axis bgcolor to erase old lines

0 commit comments

Comments
 (0)