Skip to content

Commit 11ff239

Browse files
committed
FIX: Update volume plot
1 parent f9c6ae0 commit 11ff239

File tree

2 files changed

+47
-32
lines changed

2 files changed

+47
-32
lines changed

nibabel/tests/test_viewers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_viewer():
2727
a = np.sin(np.linspace(0, np.pi, 20))
2828
b = np.sin(np.linspace(0, np.pi*5, 30))
2929
data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis]
30+
data = data * np.array([1., 2.]) # give it a # of volumes > 1
3031
viewer = OrthoSlicer3D(data)
3132
plt.draw()
3233

nibabel/viewers.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
9393
self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0],
9494
v=axes[1, 1])
9595
plt.tight_layout(pad=0.1)
96-
if not self.multi_volume:
96+
if self.n_volumes <= 1:
9797
fig.delaxes(self._axes['v'])
9898
del self._axes['v']
9999
else:
@@ -109,7 +109,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
109109
colors = dict()
110110
for k, size in zip('xyz', self._data.shape[:3]):
111111
self._idx[k] = size // 2
112-
self._ims[k] = self._axes[k].imshow(self._get_slice(k), **kw)
112+
self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw)
113113
self._sizes[k] = size
114114
colors[k] = (0, 1, 0)
115115
self._idx['v'] = 0
@@ -148,23 +148,24 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
148148
ax.axes.get_xaxis().set_visible(False)
149149

150150
# Set up volumes axis
151-
if self.multi_volume and 'v' in self._axes:
151+
if self.n_volumes > 1 and 'v' in self._axes:
152152
ax = self._axes['v']
153153
ax.set_axis_bgcolor('k')
154154
ax.set_title('Volumes')
155-
n_vols = np.prod(self._volume_dims)
156-
y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel()
157-
y = np.concatenate((y, [y[-1]]))
158-
x = np.arange(n_vols + 1) - 0.5
155+
y = self._get_voxel_levels()
156+
x = np.arange(self.n_volumes + 1) - 0.5
159157
step = ax.step(x, y, where='post', color='y')[0]
160-
ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int)))
158+
ax.set_xticks(np.unique(np.linspace(0, self.n_volumes - 1,
159+
5).astype(int)))
161160
ax.set_xlim(x[0], x[-1])
162-
lims = ax.get_ylim()
163-
patch = mpl_patch.Rectangle([-0.5, lims[0]], 1., np.diff(lims)[0],
164-
fill=True, facecolor=(0, 1, 0),
165-
edgecolor=(0, 1, 0), alpha=0.25)
161+
yl = [self._data.min(), self._data.max()]
162+
yl = [l + s * np.diff(lims)[0] for l, s in zip(yl, [-1.01, 1.01])]
163+
patch = mpl_patch.Rectangle([-0.5, yl[0]], 1., np.diff(yl)[0],
164+
fill=True, facecolor=(0, 1, 0),
165+
edgecolor=(0, 1, 0), alpha=0.25)
166166
ax.add_patch(patch)
167-
self._time_lines = [patch, step]
167+
ax.set_ylim(yl)
168+
self._volume_ax_objs = dict(step=step, patch=patch)
168169

169170
# setup pairwise connections between the slice dimensions
170171
self._click_update_keys = dict(x='yz', y='xz', z='xy')
@@ -197,9 +198,9 @@ def close(self):
197198
plt.close(f)
198199

199200
@property
200-
def multi_volume(self):
201-
"""Whether or not the displayed data is multi-volume"""
202-
return len(self._volume_dims) > 0
201+
def n_volumes(self):
202+
"""Number of volumes in the data"""
203+
return int(np.prod(self._volume_dims))
203204

204205
def set_indices(self, x=None, y=None, z=None, v=None):
205206
"""Set current displayed slice indices
@@ -221,39 +222,51 @@ def set_indices(self, x=None, y=None, z=None, v=None):
221222
v = int(v) if v is not None else None
222223
draw = False
223224
if v is not None:
224-
if not self.multi_volume:
225+
if self.n_volumes <= 1:
225226
raise ValueError('cannot change volume index of single-volume '
226227
'image')
227-
self._set_vol_idx(v, draw=False) # delay draw
228+
self._set_vol_idx(v)
228229
draw = True
229230
for key, val in zip('zyx', (z, y, x)):
230231
if val is not None:
231232
self._set_viewer_slice(key, val)
232233
draw = True
233234
if draw:
235+
self._update_voxel_levels()
234236
self._draw()
235237

236-
def _set_vol_idx(self, idx, draw=True):
237-
"""Helper to change which volume is shown"""
238+
def _get_voxel_levels(self):
239+
"""Get levels of the current voxel as a function of volume"""
240+
y = self._data[self._idx['x'],
241+
self._idx['y'],
242+
self._idx['z'], :].ravel()
243+
y = np.concatenate((y, [y[-1]]))
244+
return y
245+
246+
def _update_voxel_levels(self):
247+
"""Update voxel levels in time plot"""
248+
if self.n_volumes > 1:
249+
self._volume_ax_objs['step'].set_ydata(self._get_voxel_levels())
250+
251+
def _set_vol_idx(self, idx):
252+
"""Change which volume is shown"""
238253
max_ = np.prod(self._volume_dims)
239254
self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0)
240255
# Must reset what is shown
241256
self._current_vol_data = self._data[:, :, :, self._idx['v']]
242257
for key in 'xyz':
243-
self._ims[key].set_data(self._get_slice(key))
244-
self._time_lines[0].set_x(self._idx['v'] - 0.5)
245-
if draw:
246-
self._draw()
258+
self._ims[key].set_data(self._get_slice_data(key))
259+
self._volume_ax_objs['patch'].set_x(self._idx['v'] - 0.5)
247260

248-
def _get_slice(self, key):
261+
def _get_slice_data(self, key):
249262
"""Helper to get the current slice image"""
250263
ii = dict(x=0, y=1, z=2)[key]
251264
return np.take(self._current_vol_data, self._idx[key], axis=ii).T
252265

253266
def _set_viewer_slice(self, key, idx):
254267
"""Helper to set a viewer slice number"""
255268
self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0)
256-
self._ims[key].set_data(self._get_slice(key))
269+
self._ims[key].set_data(self._get_slice_data(key))
257270
for fun in self._cross_setters[key]:
258271
fun([self._idx[key]] * 2)
259272

@@ -272,14 +285,15 @@ def _on_scroll(self, event):
272285
return
273286
delta = 10 if event.key is not None and 'control' in event.key else 1
274287
if event.key is not None and 'shift' in event.key:
275-
if not self.multi_volume:
288+
if self.n_volumes <= 1:
276289
return
277290
key = 'v' # shift: change volume in any axis
278291
idx = self._idx[key] + (delta if event.button == 'up' else -delta)
279292
if key == 'v':
280293
self._set_vol_idx(idx)
281294
else:
282295
self._set_viewer_slice(key, idx)
296+
self._update_voxel_levels()
283297
self._draw()
284298

285299
def _on_mousemove(self, event):
@@ -294,6 +308,7 @@ def _on_mousemove(self, event):
294308
for sub_key, idx in zip(self._click_update_keys[key],
295309
(event.xdata, event.ydata)):
296310
self._set_viewer_slice(sub_key, idx)
311+
self._update_voxel_levels()
297312
self._draw()
298313

299314
def _on_keypress(self, event):
@@ -303,16 +318,15 @@ def _on_keypress(self, event):
303318
def _draw(self):
304319
for im in self._ims.values():
305320
ax = im.axes
306-
ax.draw_artist(ax.patch)
307321
ax.draw_artist(im)
308322
ax.draw_artist(im.vert_line)
309323
ax.draw_artist(im.horiz_line)
310324
ax.figure.canvas.blit(ax.bbox)
311325
for t in im.texts:
312326
ax.draw_artist(t)
313-
if self.multi_volume and 'v' in self._axes: # user might only pass 3
327+
if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3
314328
ax = self._axes['v']
315-
ax.draw_artist(ax.patch)
316-
for artist in self._time_lines:
317-
ax.draw_artist(artist)
329+
ax.draw_artist(ax.patch) # axis bgcolor to erase old lines
330+
for key in ('step', 'patch'):
331+
ax.draw_artist(self._volume_ax_objs[key])
318332
ax.figure.canvas.blit(ax.bbox)

0 commit comments

Comments
 (0)