Skip to content

Commit f9c6ae0

Browse files
committed
ENH: Better tests
1 parent e55a15e commit f9c6ae0

File tree

3 files changed

+49
-39
lines changed

3 files changed

+49
-39
lines changed

nibabel/spatialimages.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -663,20 +663,18 @@ def __getitem__(self):
663663
"array data with `img.dataobj[slice]` or "
664664
"`img.get_data()[slice]`")
665665

666-
def plot(self, show=True):
666+
def plot(self):
667667
"""Plot the image using OrthoSlicer3D
668668
669-
Parameters
670-
----------
671-
show : bool
672-
If True, the viewer will be shown.
673-
674669
Returns
675670
-------
676671
viewer : instance of OrthoSlicer3D
677672
The viewer.
673+
674+
Notes
675+
-----
676+
This requires matplotlib. If a non-interactive backend is used,
677+
consider using viewer.show() (equivalently plt.show()) to show
678+
the figure.
678679
"""
679-
out = OrthoSlicer3D(self.get_data())
680-
if show:
681-
out.show()
682-
return out
680+
return OrthoSlicer3D(self.get_data())

nibabel/tests/test_viewers.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,39 @@ def test_viewer():
2626
# Test viewer
2727
a = np.sin(np.linspace(0, np.pi, 20))
2828
b = np.sin(np.linspace(0, np.pi*5, 30))
29-
data = np.outer(a, b)[..., np.newaxis] * a
29+
data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis]
3030
viewer = OrthoSlicer3D(data)
3131
plt.draw()
3232

33-
# fake some events
34-
viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes
35-
viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes
33+
# fake some events, inside and outside axes
34+
viewer._on_scroll(nt('event', 'button inaxes key')('up', None, None))
35+
for ax in (viewer._axes['x'], viewer._axes['v']):
36+
viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, None))
37+
viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift'))
3638
# "click" outside axes, then once in each axis, then move without click
37-
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
38-
None, 1))
39-
for im in viewer._ims:
40-
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
41-
im.axes,
42-
1))
43-
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
44-
None, None))
39+
viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
40+
None, 1))
41+
for ax in viewer._axes.values():
42+
viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
43+
ax, 1))
44+
viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
45+
None, None))
4546
viewer.set_indices(0, 1, 2)
47+
viewer.set_indices(v=10)
4648
viewer.close()
4749

50+
# non-multi-volume
51+
viewer = OrthoSlicer3D(data[:, :, :, 0])
52+
assert_raises(ValueError, viewer.set_indices, v=10) # not multi-volume
53+
viewer._on_scroll(nt('event', 'button inaxes key')('up', viewer._axes['x'],
54+
'shift'))
55+
viewer._on_keypress(nt('event', 'key')('escape'))
56+
4857
# other cases
49-
fig, axes = plt.subplots(1, 3)
58+
fig, axes = plt.subplots(1, 4)
5059
plt.close(fig)
51-
OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes)
52-
assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2, 3],
53-
axes=axes)
60+
OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes,
61+
aspect_ratio=[1, 2, 3])
62+
OrthoSlicer3D(data, axes=axes[:3])
63+
assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2])
64+
assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0])

nibabel/viewers.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
plt, _, _ = optional_package('matplotlib.pyplot')
1313
mpl_img, _, _ = optional_package('matplotlib.image')
14+
mpl_patch, _, _ = optional_package('matplotlib.patches')
1415

1516

1617
class OrthoSlicer3D(object):
@@ -61,7 +62,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
6162
aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2])
6263
data = np.asanyarray(data)
6364
if data.ndim < 3:
64-
raise RuntimeError('data must have at least 3 dimensions')
65+
raise ValueError('data must have at least 3 dimensions')
6566
self._volume_dims = data.shape[3:]
6667
self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data
6768
self._data = data
@@ -147,20 +148,23 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
147148
ax.axes.get_xaxis().set_visible(False)
148149

149150
# Set up volumes axis
150-
if self.multi_volume:
151+
if self.multi_volume and 'v' in self._axes:
151152
ax = self._axes['v']
152153
ax.set_axis_bgcolor('k')
153154
ax.set_title('Volumes')
154155
n_vols = np.prod(self._volume_dims)
155-
print(n_vols)
156156
y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel()
157157
y = np.concatenate((y, [y[-1]]))
158158
x = np.arange(n_vols + 1) - 0.5
159159
step = ax.step(x, y, where='post', color='y')[0]
160160
ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int)))
161161
ax.set_xlim(x[0], x[-1])
162-
line = ax.plot([0, 0], ax.get_ylim(), color=(0, 1, 0))[0]
163-
self._time_lines = [line, step]
162+
lims = ax.get_ylim()
163+
patch = mpl_patch.Rectangle([-0.5, lims[0]], 1., np.diff(lims)[0],
164+
fill=True, facecolor=(0, 1, 0),
165+
edgecolor=(0, 1, 0), alpha=0.25)
166+
ax.add_patch(patch)
167+
self._time_lines = [patch, step]
164168

165169
# setup pairwise connections between the slice dimensions
166170
self._click_update_keys = dict(x='yz', y='xz', z='xy')
@@ -180,11 +184,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
180184
fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove)
181185
fig.canvas.mpl_connect('button_press_event', self._on_mousemove)
182186
fig.canvas.mpl_connect('key_press_event', self._on_keypress)
183-
plt.draw()
184-
self._draw()
185187

186188
def show(self):
187-
""" Show the slicer; convenience for ``plt.show()``
189+
""" Show the slicer in blocking mode; convenience for ``plt.show()``
188190
"""
189191
plt.show()
190192

@@ -220,8 +222,8 @@ def set_indices(self, x=None, y=None, z=None, v=None):
220222
draw = False
221223
if v is not None:
222224
if not self.multi_volume:
223-
raise RuntimeError('cannot change volume index of '
224-
'single-volume image')
225+
raise ValueError('cannot change volume index of single-volume '
226+
'image')
225227
self._set_vol_idx(v, draw=False) # delay draw
226228
draw = True
227229
for key, val in zip('zyx', (z, y, x)):
@@ -239,7 +241,7 @@ def _set_vol_idx(self, idx, draw=True):
239241
self._current_vol_data = self._data[:, :, :, self._idx['v']]
240242
for key in 'xyz':
241243
self._ims[key].set_data(self._get_slice(key))
242-
self._time_lines[0].set_xdata([self._idx['v']] * 2)
244+
self._time_lines[0].set_x(self._idx['v'] - 0.5)
243245
if draw:
244246
self._draw()
245247

@@ -262,7 +264,6 @@ def _in_axis(self, event):
262264
for key, ax in self._axes.items():
263265
if event.inaxes is ax:
264266
return key
265-
return None
266267

267268
def _on_scroll(self, event):
268269
assert event.button in ('up', 'down')
@@ -309,7 +310,7 @@ def _draw(self):
309310
ax.figure.canvas.blit(ax.bbox)
310311
for t in im.texts:
311312
ax.draw_artist(t)
312-
if self.multi_volume:
313+
if self.multi_volume and 'v' in self._axes: # user might only pass 3
313314
ax = self._axes['v']
314315
ax.draw_artist(ax.patch)
315316
for artist in self._time_lines:

0 commit comments

Comments
 (0)