@@ -33,8 +33,8 @@ class OrthoSlicer3D(object):
33
33
>>> OrthoSlicer3D(data).show() # doctest: +SKIP
34
34
"""
35
35
# Skip doctest above b/c not all systems have mpl installed
36
- def __init__ (self , data , axes = None , aspect_ratio = (1 , 1 , 1 ), cmap = 'gray' ,
37
- pcnt_range = (1. , 99. ), figsize = (8 , 8 )):
36
+ def __init__ (self , data , axes = None , aspect_ratio = (1 , 1 , 1 ), affine = None ,
37
+ cmap = 'gray' , pcnt_range = (1. , 99. ), figsize = (8 , 8 )):
38
38
"""
39
39
Parameters
40
40
----------
@@ -46,13 +46,14 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
46
46
or None (default).
47
47
aspect_ratio : array-like, optional
48
48
Stretch factors for X, Y, Z directions.
49
+ affine : array-like | None
50
+ Affine transform for the data. This is used to determine
51
+ how the data should be sliced for plotting into the X, Y,
52
+ and Z view axes. If None, identity is assumed.
49
53
cmap : str | instance of cmap, optional
50
- String or cmap instance specifying colormap. Will be passed as
51
- ``cmap`` argument to ``plt.imshow``.
54
+ String or cmap instance specifying colormap.
52
55
pcnt_range : array-like, optional
53
- Percentile range over which to scale image for display. If None,
54
- scale between image mean and max. If sequence, min and max
55
- percentile over which to scale image.
56
+ Percentile range over which to scale image for display.
56
57
figsize : tuple
57
58
Figure size (in inches) to use if axes are None.
58
59
"""
@@ -63,6 +64,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
63
64
data = np .asanyarray (data )
64
65
if data .ndim < 3 :
65
66
raise ValueError ('data must have at least 3 dimensions' )
67
+ affine = np .array (affine , float ) if affine is not None else np .eye (4 )
68
+ if affine .ndim != 2 or affine .shape != (4 , 4 ):
69
+ raise ValueError ('affine must be a 4x4 matrix' )
70
+ self ._affine = affine
66
71
self ._volume_dims = data .shape [3 :]
67
72
self ._current_vol_data = data [:, :, :, 0 ] if data .ndim > 3 else data
68
73
self ._data = data
@@ -116,17 +121,16 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
116
121
labels = dict (z = 'ILSR' , y = 'ALPR' , x = 'AIPS' )
117
122
118
123
# set up axis crosshairs
124
+ self ._crosshairs = dict ()
119
125
for type_ , i_1 , i_2 in zip ('zyx' , 'xxy' , 'yzz' ):
120
- ax = self ._axes [type_ ]
121
- im = self ._ims [type_ ]
122
- label = labels [type_ ]
123
- # add slice lines
124
- im .vert_line = ax .plot ([self ._idx [i_1 ]] * 2 ,
125
- [- 0.5 , self ._sizes [i_2 ] - 0.5 ],
126
- color = colors [i_1 ], linestyle = '-' )[0 ]
127
- im .horiz_line = ax .plot ([- 0.5 , self ._sizes [i_1 ] - 0.5 ],
128
- [self ._idx [i_2 ]] * 2 ,
129
- color = colors [i_2 ], linestyle = '-' )[0 ]
126
+ ax , label = self ._axes [type_ ], labels [type_ ]
127
+ vert = ax .plot ([self ._idx [i_1 ]] * 2 ,
128
+ [- 0.5 , self ._sizes [i_2 ] - 0.5 ],
129
+ color = colors [i_1 ], linestyle = '-' )[0 ]
130
+ horiz = ax .plot ([- 0.5 , self ._sizes [i_1 ] - 0.5 ],
131
+ [self ._idx [i_2 ]] * 2 ,
132
+ color = colors [i_2 ], linestyle = '-' )[0 ]
133
+ self ._crosshairs [type_ ] = dict (vert = vert , horiz = horiz )
130
134
# add text labels (top, right, bottom, left)
131
135
lims = [0 , self ._sizes [i_1 ], 0 , self ._sizes [i_2 ]]
132
136
bump = 0.01
@@ -136,10 +140,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
136
140
[lims [0 ] - bump * lims [1 ], lims [3 ] / 2. ]]
137
141
anchors = [['center' , 'bottom' ], ['left' , 'center' ],
138
142
['center' , 'top' ], ['right' , 'center' ]]
139
- im . texts = [ ax . text ( pos [ 0 ], pos [ 1 ] , lab ,
140
- horizontalalignment = anchor [0 ],
141
- verticalalignment = anchor [1 ])
142
- for pos , anchor , lab in zip ( poss , anchors , label )]
143
+ for pos , anchor , lab in zip ( poss , anchors , label ):
144
+ ax . text ( pos [0 ], pos [ 1 ], lab ,
145
+ horizontalalignment = anchor [0 ],
146
+ verticalalignment = anchor [ 1 ])
143
147
ax .axis (lims )
144
148
ax .set_aspect (aspect_ratio [type_ ])
145
149
ax .patch .set_visible (False )
@@ -172,18 +176,18 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray',
172
176
173
177
# when an index changes, which crosshairs need to be updated
174
178
self ._cross_setters = dict (
175
- x = [self ._ims ['z' ]. vert_line .set_xdata ,
176
- self ._ims ['y' ]. vert_line .set_xdata ],
177
- y = [self ._ims ['z' ]. horiz_line .set_ydata ,
178
- self ._ims ['x' ]. vert_line .set_xdata ],
179
- z = [self ._ims ['y' ]. horiz_line .set_ydata ,
180
- self ._ims ['x' ]. horiz_line .set_ydata ])
179
+ x = [self ._crosshairs ['z' ][ 'vert' ] .set_xdata ,
180
+ self ._crosshairs ['y' ][ 'vert' ] .set_xdata ],
181
+ y = [self ._crosshairs ['z' ][ 'horiz' ] .set_ydata ,
182
+ self ._crosshairs ['x' ][ 'vert' ] .set_xdata ],
183
+ z = [self ._crosshairs ['y' ][ 'horiz' ] .set_ydata ,
184
+ self ._crosshairs ['x' ][ 'horiz' ] .set_ydata ])
181
185
182
186
self ._figs = set ([a .figure for a in self ._axes .values ()])
183
187
for fig in self ._figs :
184
188
fig .canvas .mpl_connect ('scroll_event' , self ._on_scroll )
185
- fig .canvas .mpl_connect ('motion_notify_event' , self ._on_mousemove )
186
- fig .canvas .mpl_connect ('button_press_event' , self ._on_mousemove )
189
+ fig .canvas .mpl_connect ('motion_notify_event' , self ._on_mouse )
190
+ fig .canvas .mpl_connect ('button_press_event' , self ._on_mouse )
187
191
fig .canvas .mpl_connect ('key_press_event' , self ._on_keypress )
188
192
189
193
def show (self ):
@@ -279,6 +283,7 @@ def _in_axis(self, event):
279
283
return key
280
284
281
285
def _on_scroll (self , event ):
286
+ """Handle mpl scroll wheel event"""
282
287
assert event .button in ('up' , 'down' )
283
288
key = self ._in_axis (event )
284
289
if key is None :
@@ -296,7 +301,8 @@ def _on_scroll(self, event):
296
301
self ._update_voxel_levels ()
297
302
self ._draw ()
298
303
299
- def _on_mousemove (self , event ):
304
+ def _on_mouse (self , event ):
305
+ """Handle mpl mouse move and button press events"""
300
306
if event .button != 1 : # only enabled while dragging
301
307
return
302
308
key = self ._in_axis (event )
@@ -312,18 +318,18 @@ def _on_mousemove(self, event):
312
318
self ._draw ()
313
319
314
320
def _on_keypress (self , event ):
321
+ """Handle mpl keypress events"""
315
322
if event .key is not None and 'escape' in event .key :
316
323
self .close ()
317
324
318
325
def _draw (self ):
319
- for im in self ._ims .values ():
320
- ax = im .axes
326
+ """Update all four (or three) plots"""
327
+ for key in 'xyz' :
328
+ ax , im = self ._axes [key ], self ._ims [key ]
321
329
ax .draw_artist (im )
322
- ax . draw_artist ( im . vert_line )
323
- ax .draw_artist (im . horiz_line )
330
+ for line in self . _crosshairs [ key ]. values ():
331
+ ax .draw_artist (line )
324
332
ax .figure .canvas .blit (ax .bbox )
325
- for t in im .texts :
326
- ax .draw_artist (t )
327
333
if self .n_volumes > 1 and 'v' in self ._axes : # user might only pass 3
328
334
ax = self ._axes ['v' ]
329
335
ax .draw_artist (ax .patch ) # axis bgcolor to erase old lines
0 commit comments