Skip to content

Commit 6fda4d8

Browse files
committed
ENH: Add set_indices method
1 parent 5097571 commit 6fda4d8

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

nibabel/tests/test_viewers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ def test_viewer():
3333
# fake some events
3434
viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes
3535
viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes
36-
# tracking on
36+
# "click" outside axes, then once in each axis, then move without click
3737
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
3838
None, 1))
39-
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
40-
plt.gca(), 1))
41-
# tracking off
39+
for im in viewer._ims:
40+
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
41+
im.axes,
42+
1))
4243
viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5,
4344
None, None))
45+
viewer.set_indices(0, 1, 2)
4446
viewer.close()
4547

4648
# other cases

nibabel/viewers.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535

3636
def _set_viewer_slice(idx, im):
3737
"""Helper to set a viewer slice number"""
38-
im.idx = idx
38+
im.idx = max(min(int(round(idx)), im.size - 1), 0)
3939
im.set_data(im.get_slice(im.idx))
4040
for fun in im.cross_setters:
41-
fun([idx] * 2)
41+
fun([im.idx] * 2)
4242

4343

4444
class OrthoSlicer3D(object):
@@ -133,24 +133,21 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
133133
origin='lower')
134134

135135
# Start midway through each axis
136-
st_x, st_y, st_z = (data_shape - 1) / 2.
137-
sts = (st_x, st_y, st_z)
138-
n_x, n_y, n_z = data_shape
139-
z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T
140-
y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T
141-
x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T
142-
im1 = ax1.imshow(z_get_slice(st_z), **kw)
143-
im2 = ax2.imshow(y_get_slice(st_y), **kw)
144-
im3 = ax3.imshow(x_get_slice(st_x), **kw)
136+
z_get_slice = lambda i: self.data[:, :, i].T
137+
y_get_slice = lambda i: self.data[:, i, :].T
138+
x_get_slice = lambda i: self.data[i, :, :].T
139+
sts = (data_shape - 1) // 2
140+
im1 = ax1.imshow(z_get_slice(sts[2]), **kw)
141+
im2 = ax2.imshow(y_get_slice(sts[1]), **kw)
142+
im3 = ax3.imshow(x_get_slice(sts[0]), **kw)
143+
# idx is the current slice number for each panel
144+
im1.idx, im2.idx, im3.idx = sts
145+
self._ims = (im1, im2, im3)
145146
im1.get_slice, im2.get_slice, im3.get_slice = (
146147
z_get_slice, y_get_slice, x_get_slice)
147-
self._ims = (im1, im2, im3)
148-
149-
# idx is the current slice number for each panel
150-
im1.idx, im2.idx, im3.idx = st_z, st_y, st_x
151148

152149
# set the maximum dimensions for indexing
153-
im1.size, im2.size, im3.size = n_z, n_y, n_x
150+
im1.size, im2.size, im3.size = data_shape
154151

155152
# set up axis crosshairs
156153
colors = ['r', 'g', 'b']
@@ -191,6 +188,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
191188
for fig in self.figs:
192189
fig.canvas.mpl_connect('scroll_event', self.on_scroll)
193190
fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove)
191+
fig.canvas.mpl_connect('button_press_event', self.on_mousemove)
194192

195193
def show(self):
196194
""" Show the slicer; convenience for ``plt.show()``
@@ -203,6 +201,26 @@ def close(self):
203201
for f in self.figs:
204202
plt.close(f)
205203

204+
def set_indices(self, x=None, y=None, z=None):
205+
"""Set current displayed slice indices
206+
207+
Parameters
208+
----------
209+
x : int | None
210+
Index to use. If None, do not change.
211+
y : int | None
212+
Index to use. If None, do not change.
213+
z : int | None
214+
Index to use. If None, do not change.
215+
"""
216+
draw = False
217+
for im, val in zip(self._ims, (z, y, x)):
218+
if val is not None:
219+
im.set_viewer_slice(val)
220+
draw = True
221+
if draw:
222+
self._draw_ims()
223+
206224
def _axis_artist(self, event):
207225
"""Return artist if within axes, and is an image, else None
208226
"""
@@ -216,8 +234,7 @@ def on_scroll(self, event):
216234
im = self._axis_artist(event)
217235
if im is None:
218236
return
219-
idx = (im.idx + (1 if event.button == 'up' else -1))
220-
idx = max(min(idx, im.size - 1), 0)
237+
idx = im.idx + (1 if event.button == 'up' else -1)
221238
im.set_viewer_slice(idx)
222239
self._draw_ims()
223240

@@ -227,9 +244,7 @@ def on_mousemove(self, event):
227244
im = self._axis_artist(event)
228245
if im is None:
229246
return
230-
x_im, y_im = im.x_im, im.y_im
231-
x, y = np.round((event.xdata, event.ydata)).astype(int)
232-
for i, idx in zip((x_im, y_im), (x, y)):
247+
for i, idx in zip((im.x_im, im.y_im), (event.xdata, event.ydata)):
233248
i.set_viewer_slice(idx)
234249
self._draw_ims()
235250

0 commit comments

Comments
 (0)