1
1
""" Utilities for viewing images
2
2
3
- Includes version of OrthoSlicer3D code by our own Paul Ivanov
3
+ Includes version of OrthoSlicer3D code originally written by our own
4
+ Paul Ivanov.
4
5
"""
5
6
from __future__ import division , print_function
6
7
7
8
import numpy as np
9
+ from functools import partial
8
10
9
11
from .optpkg import optional_package
10
12
30
32
# v +---------+
31
33
# <-- x -->
32
34
35
+
36
+ def _set_viewer_slice (idx , im ):
37
+ """Helper to set a viewer slice number"""
38
+ im .idx = idx
39
+ im .set_data (im .get_slice (im .idx ))
40
+ for fun in im .cross_setters :
41
+ fun ([idx ] * 2 )
42
+
43
+
33
44
class OrthoSlicer3D (object ):
34
45
"""Orthogonal-plane slicer.
35
46
36
47
OrthoSlicer3d expects 3-dimensional data, and by default it creates a
37
48
figure with 3 axes, one for each slice orientation.
38
49
39
- There are two modes, "following on" and "following off". In "following on"
40
- mode, moving the mouse in any one axis will select out the corresponding
41
- slices in the other two. The mode is "following off" when the figure is
42
- first created. Clicking the left mouse button toggles mouse following and
43
- triggers a full redraw (to update the ticks, for example). Scrolling up and
50
+ Clicking and dragging the mouse in any one axis will select out the
51
+ corresponding slices in the other two. Scrolling up and
44
52
down moves the slice up and down in the current axis.
45
53
46
54
Example
47
55
-------
48
- import numpy as np
49
- a = np.sin(np.linspace(0,np.pi,20))
50
- b = np.sin(np.linspace(0,np.pi*5,20))
51
- data = np.outer(a,b)[..., np.newaxis]*a
52
- OrthoSlicer3D(data).show()
56
+ >>> import numpy as np
57
+ >>> a = np.sin(np.linspace(0,np.pi,20))
58
+ >>> b = np.sin(np.linspace(0,np.pi*5,20))
59
+ >>> data = np.outer(a,b)[..., np.newaxis]*a
60
+ >>> OrthoSlicer3D(data).show()
53
61
"""
54
62
def __init__ (self , data , axes = None , aspect_ratio = (1 , 1 , 1 ), cmap = 'gray' ,
55
63
pcnt_range = None ):
@@ -70,9 +78,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
70
78
scale between image mean and max. If sequence, min and max
71
79
percentile over which to scale image.
72
80
"""
73
- data_shape = np .array (data .shape [:3 ]) # allow trailing RGB dimension
74
- aspect_ratio = np .array (aspect_ratio )
75
- if axes is None : # make the axes
81
+ data_shape = np .array (data .shape [:3 ]) # allow trailing RGB dimension
82
+ aspect_ratio = np .array (aspect_ratio , float )
83
+ if axes is None : # make the axes
76
84
# ^ +---------+ ^ +---------+
77
85
# | | | | | |
78
86
# | | | |
@@ -122,8 +130,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
122
130
interpolation = 'nearest' ,
123
131
cmap = cmap ,
124
132
origin = 'lower' )
133
+
125
134
# Start midway through each axis
126
135
st_x , st_y , st_z = (data_shape - 1 ) / 2.
136
+ sts = (st_x , st_y , st_z )
127
137
n_x , n_y , n_z = data_shape
128
138
z_get_slice = lambda i : self .data [:, :, min (i , n_z - 1 )].T
129
139
y_get_slice = lambda i : self .data [:, min (i , n_y - 1 ), :].T
@@ -133,22 +143,51 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
133
143
im3 = ax3 .imshow (x_get_slice (st_x ), ** kw )
134
144
im1 .get_slice , im2 .get_slice , im3 .get_slice = (
135
145
z_get_slice , y_get_slice , x_get_slice )
146
+ self ._ims = (im1 , im2 , im3 )
147
+
136
148
# idx is the current slice number for each panel
137
149
im1 .idx , im2 .idx , im3 .idx = st_z , st_y , st_x
150
+
138
151
# set the maximum dimensions for indexing
139
152
im1 .size , im2 .size , im3 .size = n_z , n_y , n_x
153
+
154
+ # set up axis crosshairs
155
+ colors = ['r' , 'g' , 'b' ]
156
+ for ax , im , idx_1 , idx_2 in zip (axes , self ._ims , [0 , 0 , 1 ], [1 , 2 , 2 ]):
157
+ im .x_line = ax .plot ([sts [idx_1 ]] * 2 ,
158
+ [- 0.5 , data .shape [idx_2 ] - 0.5 ],
159
+ color = colors [idx_1 ], linestyle = '-' ,
160
+ alpha = 0.25 )[0 ]
161
+ im .y_line = ax .plot ([- 0.5 , data .shape [idx_1 ] - 0.5 ],
162
+ [sts [idx_2 ]] * 2 ,
163
+ color = colors [idx_2 ], linestyle = '-' ,
164
+ alpha = 0.25 )[0 ]
165
+ ax .axis ('tight' )
166
+ ax .patch .set_visible (False )
167
+ ax .set_frame_on (False )
168
+ ax .axes .get_yaxis ().set_visible (False )
169
+ ax .axes .get_xaxis ().set_visible (False )
170
+
171
+ # monkey-patch some functions
172
+ im1 .set_viewer_slice = partial (_set_viewer_slice , im = im1 )
173
+ im2 .set_viewer_slice = partial (_set_viewer_slice , im = im2 )
174
+ im3 .set_viewer_slice = partial (_set_viewer_slice , im = im3 )
175
+
140
176
# setup pairwise connections between the slice dimensions
141
- im1 .imx = im3 # x move in panel 1 (usually axial)
142
- im1 .imy = im2 # y move in panel 1
143
- im2 .imx = im3 # x move in panel 2 (usually coronal)
144
- im2 .imy = im1
145
- im3 .imx = im2 # x move in panel 3 (usually sagittal)
146
- im3 .imy = im1
147
-
148
- self .follow = False
177
+ im1 .x_im = im3 # x move in panel 1 (usually axial)
178
+ im1 .y_im = im2 # y move in panel 1
179
+ im2 .x_im = im3 # x move in panel 2 (usually coronal)
180
+ im2 .y_im = im1 # y move in panel 2
181
+ im3 .x_im = im2 # x move in panel 3 (usually sagittal)
182
+ im3 .y_im = im1 # y move in panel 3
183
+
184
+ # when an index changes, which crosshairs need to be updated
185
+ im1 .cross_setters = [im2 .y_line .set_ydata , im3 .y_line .set_ydata ]
186
+ im2 .cross_setters = [im1 .y_line .set_ydata , im3 .x_line .set_xdata ]
187
+ im3 .cross_setters = [im1 .x_line .set_xdata , im2 .x_line .set_xdata ]
188
+
149
189
self .figs = set ([ax .figure for ax in axes ])
150
190
for fig in self .figs :
151
- fig .canvas .mpl_connect ('button_press_event' , self .on_click )
152
191
fig .canvas .mpl_connect ('scroll_event' , self .on_scroll )
153
192
fig .canvas .mpl_connect ('motion_notify_event' , self .on_mousemove )
154
193
@@ -157,59 +196,46 @@ def show(self):
157
196
"""
158
197
plt .show ()
159
198
199
+ def close (self ):
200
+ """Close the viewer figures
201
+ """
202
+ for f in self .figs :
203
+ plt .close (f )
204
+
160
205
def _axis_artist (self , event ):
161
- """ Return artist if within axes, and is an image, else None
206
+ """Return artist if within axes, and is an image, else None
162
207
"""
163
208
if not getattr (event , 'inaxes' ):
164
209
return None
165
210
artist = event .inaxes .images [0 ]
166
211
return artist if isinstance (artist , mpl_img .AxesImage ) else None
167
212
168
- def on_click (self , event ):
169
- if event .button == 1 :
170
- self .follow = not self .follow
171
- plt .draw ()
172
-
173
213
def on_scroll (self , event ):
174
214
assert event .button in ('up' , 'down' )
175
215
im = self ._axis_artist (event )
176
216
if im is None :
177
217
return
178
- im .idx += 1 if event .button == 'up' else - 1
179
- im .idx %= im .size
180
- im .set_data (im .get_slice (im .idx ))
181
- ax = im .axes
182
- ax .draw_artist (im )
183
- ax .figure .canvas .blit (ax .bbox )
218
+ idx = (im .idx + (1 if event .button == 'up' else - 1 ))
219
+ idx = max (min (idx , im .size - 1 ), 0 )
220
+ im .set_viewer_slice (idx )
221
+ self ._draw_ims ()
184
222
185
223
def on_mousemove (self , event ):
186
- if not self . follow :
224
+ if event . button != 1 : # only enabled while dragging
187
225
return
188
226
im = self ._axis_artist (event )
189
227
if im is None :
190
228
return
191
- ax = im .axes
192
- imx , imy = im .imx , im .imy
229
+ x_im , y_im = im .x_im , im .y_im
193
230
x , y = np .round ((event .xdata , event .ydata )).astype (int )
194
- imx .set_data (imx .get_slice (x ))
195
- imy .set_data (imy .get_slice (y ))
196
- imx .idx = x
197
- imy .idx = y
198
- for i in imx , imy :
199
- ax = i .axes
200
- ax .draw_artist (i )
231
+ for i , idx in zip ((x_im , y_im ), (x , y )):
232
+ i .set_viewer_slice (idx )
233
+ self ._draw_ims ()
234
+
235
+ def _draw_ims (self ):
236
+ for im in self ._ims :
237
+ ax = im .axes
238
+ ax .draw_artist (im )
239
+ ax .draw_artist (im .x_line )
240
+ ax .draw_artist (im .y_line )
201
241
ax .figure .canvas .blit (ax .bbox )
202
-
203
-
204
- if __name__ == '__main__' :
205
- a = np .sin (np .linspace (0 ,np .pi ,20 ))
206
- b = np .sin (np .linspace (0 ,np .pi * 5 ,20 ))
207
- data = np .outer (a ,b )[..., np .newaxis ]* a
208
- # all slices
209
- OrthoSlicer3D (data ).show ()
210
-
211
- # broken out into three separate figures
212
- f , ax1 = plt .subplots ()
213
- f , ax2 = plt .subplots ()
214
- f , ax3 = plt .subplots ()
215
- OrthoSlicer3D (data , axes = (ax1 , ax2 , ax3 )).show ()
0 commit comments