1010from mpl_toolkits .mplot3d import Axes3D
1111
1212
13- def scatter_3D (a , cmap = "jet" , sca_args = {} , control = "color" , size = 60 ):
13+ def scatter_3D (a , cmap = "jet" , sca_args = None , control = "color" , size = 60 ):
1414 # default arguments for the quiver plot. can be overwritten by quiv_args
15+ if not isinstance (sca_args ,dict ):
16+ sca_args = {}
1517 scatter_args = {"alpha" : 1 }
1618 scatter_args .update (sca_args )
1719
@@ -25,7 +27,7 @@ def scatter_3D(a, cmap="jet", sca_args={}, control="color", size=60):
2527
2628 if control == "color" :
2729 # make cmap
28- cbound = [0 , np .nanmax (a )]
30+ cbound = [np . nanmin ( a ) , np .nanmax (a )]
2931 # create normalized color map for arrows
3032 norm = matplotlib .colors .Normalize (vmin = cbound [0 ], vmax = cbound [1 ]) # 10 ) #cbound[1] ) #)
3133 sm = matplotlib .cm .ScalarMappable (cmap = cmap , norm = norm )
@@ -39,8 +41,8 @@ def scatter_3D(a, cmap="jet", sca_args={}, control="color", size=60):
3941 plt .colorbar (sm )
4042
4143 if control == "alpha" :
42- # untested### #
43- col = [(0 , 0 , 1 , x / np .max (z )) for x in np .ravel (z )]
44+ # untested #
45+ colors = [(0 , 0 , 1 , x / np .max (z )) for x in np .ravel (z )]
4446 ax .scatter (x , y , z , c = colors , s = size , ** scatter_args )
4547 plt .show ()
4648
@@ -86,6 +88,8 @@ def explode(data):
8688
8789
8890def plot_3D_alpha (data ):
91+ # plotting each voxel as a slightly smaller block with transparency depending
92+ # on the data value
8993 # following "https://matplotlib.org/3.1.1/gallery/mplot3d/voxels_numpy_logo.html"
9094
9195 col = np .zeros ((data .shape [0 ], data .shape [1 ], data .shape [2 ], 4 ))
@@ -119,32 +123,80 @@ def plot_3D_alpha(data):
119123 plt .show ()
120124
121125
122- def quiver_3D (u , v , w , x = None , y = None , z = None , image_dim = None , mask_filtered = None , filter_def = 0 , filter_reg = (1 , 1 , 1 ),
123- cmap = "jet" , quiv_args = {}, cbound = None ):
124- # filter_def filters values with smaler absolute deformation
125- # nans are also removed
126- # setting the filter to <0 will probably mess up the arrow colors
127- # filter_reg filters every n-th value, separate for x, y, z axis
128- # you can also provide your own mask with mask_filtered !!! make sure to filter out arrows with zero total deformation!!!!
129- # other wise the arrows are not colored correctly
130- # use indices for x,y,z axis as default - can be specified by x,y,z
126+ def quiver_3D (u , v , w , x = None , y = None , z = None , mask_filtered = None , filter_def = 0 , filter_reg = (1 , 1 , 1 ),
127+ cmap = "jet" , quiv_args = None , vmin = None , vmax = None ):
128+ """ Displaying 3D deformation fields vector arrows
129+
130+ Parameters
131+ ----------
132+ u,v,w: 3d ndarray or lists
133+ arrays or list with deformation in x,y and z direction
134+
135+ x,y,z: 3d ndarray or lists
136+ Arrays or list with deformation the coordinates of the deformations.
137+ Must match the dimensions of the u,v qnd w. If not provided x,y and z are created
138+ with np.indices(u.shape)
139+
140+ mask_filtered, boolean 3d ndarray or 1d ndarray
141+ Array, or list with same dimensions as the deformations. Defines the area where deformations are drawn
142+ filter_def: float
143+ Filter that prevents the display of deformations arrows with length < filter_def
144+ filter_reg: tuple
145+ Filter that prevents the display of every i-th deformations arrows separatly alon each axis.
146+ filter_reg=(2,2,2) means that only every second arrow along x,y z axis is displayed leading to
147+ a total reduction of displayed arrows by a factor of 8.
148+ cmap: string
149+ matplotlib colorbar that defines the coloring of the arrow
150+ quiv_args: dict
151+ Dictionary with kwargs passed on to the matplotlib quiver function.
152+
153+ vmin,vmax: float
154+ Upper and lower bounds for the colormap. Works like vmin and vmax in plt.imshow().
155+
156+ Returns
157+ -------
158+ fig: matploltib figure object
159+
160+ ax: mattplotlib axes object
161+ the holding the main 3D quiver plot
162+
163+ """
131164
132165 # default arguments for the quiver plot. can be overwritten by quiv_args
133166 quiver_args = {"normalize" : False , "alpha" : 0.8 , "pivot" : 'tail' , "linewidth" : 1 , "length" : 20 }
134- quiver_args .update (quiv_args )
135- if not isinstance (image_dim , (list , tuple , np .ndarray )):
136- image_dim = np .array (u .shape )
167+ if isinstance (quiv_args , dict ):
168+ quiver_args .update (quiv_args )
137169
170+ # generating coordinates if not provided
138171 if x is None :
139- x , y , z = np .indices (u .shape ) * (np .array (image_dim ) / np .array (u .shape ))[:, np .newaxis , np .newaxis , np .newaxis ]
140- else :
141- x , y , z = np .array ([x , y , z ]) * (np .array (image_dim ) / np .array (u .shape ))[:, np .newaxis ]
142-
172+ # if you provide deformations as a list
173+ if len (u .shape ) == 1 :
174+ x , y , z = [np .indices (u .shape )[0 ] for i in range (3 )]
175+ # if you provide deformations as an array
176+ elif len (u .shape ) == 3 :
177+ x , y , z = np .indices (u .shape )
178+ else :
179+ raise ValueError (
180+ "displacement data has wrong number of dimensions (%s). Use 1d array, list, or 3d array." % str (
181+ len (u .shape )))
182+
183+ # conversion to array
184+ x , y , z = np .array ([x , y , z ])
185+
186+ # filtering arrows for the display
143187 deformation = np .sqrt (u ** 2 + v ** 2 + w ** 2 )
144188 if not isinstance (mask_filtered , np .ndarray ):
145189 mask_filtered = deformation > filter_def
146- if isinstance (filter_reg , tuple ):
147- mask_filtered [::filter_reg [0 ], ::filter_reg [1 ], ::filter_reg [2 ]] *= True
190+ if isinstance (filter_reg , list ):
191+ show_only = np .zeros (u .shape ).astype (bool )
192+ if len (filter_reg ) == 1 :
193+ show_only [::filter_reg [0 ]] = True
194+ elif len (filter_reg ) == 3 :
195+ show_only [::filter_reg [0 ], ::filter_reg [1 ], ::filter_reg [2 ]] = True
196+ else :
197+ raise ValueError (
198+ "filter_reg data has wrong length (%s). Use list with length 1 or 3." % str (len (filter_reg .shape )))
199+ mask_filtered = np .logical_and (mask_filtered , show_only )
148200
149201 xf = x [mask_filtered ]
150202 yf = y [mask_filtered ]
@@ -154,11 +206,8 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, image_dim=None, mask_filtered=Non
154206 wf = w [mask_filtered ]
155207 df = deformation [mask_filtered ]
156208
157- # make cmap
158- if not cbound :
159- cbound = [0 , np .nanmax (df )]
160209 # create normalized color map for arrows
161- norm = matplotlib .colors .Normalize (vmin = cbound [ 0 ] , vmax = cbound [ 1 ]) # 10 ) #cbound[1] ) # )
210+ norm = matplotlib .colors .Normalize (vmin = vmin , vmax = vmax )
162211 sm = matplotlib .cm .ScalarMappable (cmap = cmap , norm = norm )
163212 sm .set_array ([])
164213 cm = matplotlib .cm .get_cmap (cmap )
@@ -167,11 +216,11 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, image_dim=None, mask_filtered=Non
167216 colors = [c for c , d in zip (colors , df ) if d > 0 ] + list (chain (* [[c , c ] for c , d in zip (colors , df ) if d > 0 ]))
168217 # colors in ax.quiver 3d is really fucked up/ will probably change with updates:
169218 # requires list with: first len(u) entries define the colors of the shaft, then the next len(u)*2 entries define
170- # the color ofleft and right arrow head side in alternating order. Try for example:
219+ # the color of left and right arrow head side in alternating order. Try for example:
171220 # colors = ["red" for i in range(len(cf))] + list(chain(*[["blue", "yellow"] for i in range(len(cf))]))
172- # to see this effect
173- # BUT WAIT THERS MORE: zeor length arrows are apparently filtered out in the matplolib with out filtering the color list appropriately
174- # so we have to do this our selfs as well
221+ # to see this effect.
222+ # BUT WAIT THERE'S MORE: zero length arrows are apparently filtered out in the matplolib with out
223+ # filtering the color list appropriately so we have to do this ourselfs as well
175224
176225 # plotting
177226 fig = plt .figure ()
@@ -189,4 +238,5 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, image_dim=None, mask_filtered=Non
189238 ax .w_xaxis .set_pane_color ((0.2 , 0.2 , 0.2 , 1.0 ))
190239 ax .w_yaxis .set_pane_color ((0.2 , 0.2 , 0.2 , 1.0 ))
191240 ax .w_zaxis .set_pane_color ((0.2 , 0.2 , 0.2 , 1.0 ))
192- return fig
241+
242+ return fig , ax
0 commit comments