|
| 1 | +""" Utilities for viewing images |
| 2 | +
|
| 3 | +Includes version of OrthoSlicer3D code by our own Paul Ivanov |
| 4 | +""" |
| 5 | +from __future__ import division, print_function |
| 6 | + |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from .optpkg import optional_package |
| 10 | + |
| 11 | +plt, _, _ = optional_package('matplotlib.pyplot') |
| 12 | +mpl_img, _, _ = optional_package('matplotlib.image') |
| 13 | + |
| 14 | +# Assumes the following layout |
| 15 | +# |
| 16 | +# ^ +---------+ ^ +---------+ |
| 17 | +# | | | | | | |
| 18 | +# | | | | |
| 19 | +# z | 2 | z | 3 | |
| 20 | +# | | | | |
| 21 | +# | | | | | | |
| 22 | +# v +---------+ v +---------+ |
| 23 | +# <-- x --> <-- y --> |
| 24 | +# ^ +---------+ |
| 25 | +# | | | |
| 26 | +# | | |
| 27 | +# y | 1 | |
| 28 | +# | | |
| 29 | +# | | | |
| 30 | +# v +---------+ |
| 31 | +# <-- x --> |
| 32 | + |
| 33 | +class OrthoSlicer3D(object): |
| 34 | + """Orthogonal-plane slicer. |
| 35 | +
|
| 36 | + OrthoSlicer3d expects 3-dimensional data, and by default it creates a |
| 37 | + figure with 3 axes, one for each slice orientation. |
| 38 | +
|
| 39 | + There are two modes, "following on" and "following off". In "following on" |
| 40 | + mode, moving the mouse in any one axis will select out the corresponding |
| 41 | + slices in the other two. The mode is "following off" when the figure is |
| 42 | + first created. Clicking the left mouse button toggles mouse following and |
| 43 | + triggers a full redraw (to update the ticks, for example). Scrolling up and |
| 44 | + down moves the slice up and down in the current axis. |
| 45 | +
|
| 46 | + Example |
| 47 | + ------- |
| 48 | + import numpy as np |
| 49 | + a = np.sin(np.linspace(0,np.pi,20)) |
| 50 | + b = np.sin(np.linspace(0,np.pi*5,20)) |
| 51 | + data = np.outer(a,b)[..., np.newaxis]*a |
| 52 | + OrthoSlicer3D(data).show() |
| 53 | + """ |
| 54 | + def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', |
| 55 | + pcnt_range=None): |
| 56 | + """ |
| 57 | + Parameters |
| 58 | + ---------- |
| 59 | + data : 3 dimensional ndarray |
| 60 | + The data that will be displayed by the slicer |
| 61 | + axes : None or length 3 sequence of mpl.Axes, optional |
| 62 | + 3 axes instances for the X, Y, and Z slices, or None (default) |
| 63 | + aspect_ratio : float or length 3 sequence, optional |
| 64 | + stretch factors for X, Y, Z directions |
| 65 | + cmap : colormap identifier, optional |
| 66 | + String or cmap instance specifying colormap. Will be passed as |
| 67 | + ``cmap`` argument to ``plt.imshow``. |
| 68 | + pcnt_range : length 2 sequence, optional |
| 69 | + Percentile range over which to scale image for display. If None, |
| 70 | + scale between image mean and max. If sequence, min and max |
| 71 | + percentile over which to scale image. |
| 72 | + """ |
| 73 | + data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension |
| 74 | + aspect_ratio = np.array(aspect_ratio) |
| 75 | + if axes is None: # make the axes |
| 76 | + # ^ +---------+ ^ +---------+ |
| 77 | + # | | | | | | |
| 78 | + # | | | | |
| 79 | + # z | 2 | z | 3 | |
| 80 | + # | | | | |
| 81 | + # | | | | | | |
| 82 | + # v +---------+ v +---------+ |
| 83 | + # <-- x --> <-- y --> |
| 84 | + # ^ +---------+ |
| 85 | + # | | | |
| 86 | + # | | |
| 87 | + # y | 1 | |
| 88 | + # | | |
| 89 | + # | | | |
| 90 | + # v +---------+ |
| 91 | + # <-- x --> |
| 92 | + fig = plt.figure() |
| 93 | + x, y, z = data_shape * aspect_ratio |
| 94 | + maxw = float(x + y) |
| 95 | + maxh = float(y + z) |
| 96 | + yh = y / maxh |
| 97 | + xw = x / maxw |
| 98 | + yw = y / maxw |
| 99 | + zh = z / maxh |
| 100 | + # z slice (if usual transverse acquisition => axial slice) |
| 101 | + ax1 = fig.add_axes((0., 0., xw, yh)) |
| 102 | + # y slice (usually coronal) |
| 103 | + ax2 = fig.add_axes((0, yh, xw, zh)) |
| 104 | + # x slice (usually sagittal) |
| 105 | + ax3 = fig.add_axes((xw, yh, yw, zh)) |
| 106 | + axes = (ax1, ax2, ax3) |
| 107 | + else: |
| 108 | + if not np.all(aspect_ratio == 1): |
| 109 | + raise ValueError('Aspect ratio must be 1 for external axes') |
| 110 | + ax1, ax2, ax3 = axes |
| 111 | + |
| 112 | + self.data = data |
| 113 | + |
| 114 | + if pcnt_range is None: |
| 115 | + vmin, vmax = data.min(), data.max() |
| 116 | + else: |
| 117 | + vmin, vmax = np.percentile(data, pcnt_range) |
| 118 | + |
| 119 | + kw = dict(vmin=vmin, |
| 120 | + vmax=vmax, |
| 121 | + aspect='auto', |
| 122 | + interpolation='nearest', |
| 123 | + cmap=cmap, |
| 124 | + origin='lower') |
| 125 | + # Start midway through each axis |
| 126 | + st_x, st_y, st_z = (data_shape - 1) / 2. |
| 127 | + n_x, n_y, n_z = data_shape |
| 128 | + z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T |
| 129 | + y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T |
| 130 | + x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T |
| 131 | + im1 = ax1.imshow(z_get_slice(st_z), **kw) |
| 132 | + im2 = ax2.imshow(y_get_slice(st_y), **kw) |
| 133 | + im3 = ax3.imshow(x_get_slice(st_x), **kw) |
| 134 | + im1.get_slice, im2.get_slice, im3.get_slice = ( |
| 135 | + z_get_slice, y_get_slice, x_get_slice) |
| 136 | + # idx is the current slice number for each panel |
| 137 | + im1.idx, im2.idx, im3.idx = st_z, st_y, st_x |
| 138 | + # set the maximum dimensions for indexing |
| 139 | + im1.size, im2.size, im3.size = n_z, n_y, n_x |
| 140 | + # setup pairwise connections between the slice dimensions |
| 141 | + im1.imx = im3 # x move in panel 1 (usually axial) |
| 142 | + im1.imy = im2 # y move in panel 1 |
| 143 | + im2.imx = im3 # x move in panel 2 (usually coronal) |
| 144 | + im2.imy = im1 |
| 145 | + im3.imx = im2 # x move in panel 3 (usually sagittal) |
| 146 | + im3.imy = im1 |
| 147 | + |
| 148 | + self.follow = False |
| 149 | + self.figs = set([ax.figure for ax in axes]) |
| 150 | + for fig in self.figs: |
| 151 | + fig.canvas.mpl_connect('button_press_event', self.on_click) |
| 152 | + fig.canvas.mpl_connect('scroll_event', self.on_scroll) |
| 153 | + fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) |
| 154 | + |
| 155 | + def show(self): |
| 156 | + """ Show the slicer; convenience for ``plt.show()`` |
| 157 | + """ |
| 158 | + plt.show() |
| 159 | + |
| 160 | + def _axis_artist(self, event): |
| 161 | + """ Return artist if within axes, and is an image, else None |
| 162 | + """ |
| 163 | + if not getattr(event, 'inaxes'): |
| 164 | + return None |
| 165 | + artist = event.inaxes.images[0] |
| 166 | + return artist if isinstance(artist, mpl_img.AxesImage) else None |
| 167 | + |
| 168 | + def on_click(self, event): |
| 169 | + if event.button == 1: |
| 170 | + self.follow = not self.follow |
| 171 | + plt.draw() |
| 172 | + |
| 173 | + def on_scroll(self, event): |
| 174 | + assert event.button in ('up', 'down') |
| 175 | + im = self._axis_artist(event) |
| 176 | + if im is None: |
| 177 | + return |
| 178 | + im.idx += 1 if event.button == 'up' else -1 |
| 179 | + im.idx %= im.size |
| 180 | + im.set_data(im.get_slice(im.idx)) |
| 181 | + ax = im.axes |
| 182 | + ax.draw_artist(im) |
| 183 | + ax.figure.canvas.blit(ax.bbox) |
| 184 | + |
| 185 | + def on_mousemove(self, event): |
| 186 | + if not self.follow: |
| 187 | + return |
| 188 | + im = self._axis_artist(event) |
| 189 | + if im is None: |
| 190 | + return |
| 191 | + ax = im.axes |
| 192 | + imx, imy = im.imx, im.imy |
| 193 | + x, y = np.round((event.xdata, event.ydata)).astype(int) |
| 194 | + imx.set_data(imx.get_slice(x)) |
| 195 | + imy.set_data(imy.get_slice(y)) |
| 196 | + imx.idx = x |
| 197 | + imy.idx = y |
| 198 | + for i in imx, imy: |
| 199 | + ax = i.axes |
| 200 | + ax.draw_artist(i) |
| 201 | + ax.figure.canvas.blit(ax.bbox) |
| 202 | + |
| 203 | + |
| 204 | +if __name__ == '__main__': |
| 205 | + a = np.sin(np.linspace(0,np.pi,20)) |
| 206 | + b = np.sin(np.linspace(0,np.pi*5,20)) |
| 207 | + data = np.outer(a,b)[..., np.newaxis]*a |
| 208 | + # all slices |
| 209 | + OrthoSlicer3D(data).show() |
| 210 | + |
| 211 | + # broken out into three separate figures |
| 212 | + f, ax1 = plt.subplots() |
| 213 | + f, ax2 = plt.subplots() |
| 214 | + f, ax3 = plt.subplots() |
| 215 | + OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show() |
0 commit comments