1- '''
1+ """
22functions to plot 3D-deformation fields and simple 3D-structures
3- '''
3+ """
44
55
66import matplotlib
1111
1212
1313def set_axes_equal (ax ):
14-
15- '''
14+
15+ """
1616 Following https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to
1717 Make axes of 3D plot have equal scale so that spheres appear as spheres,
1818 cubes as cubes, etc.. This is one possible solution to Matplotlib's
@@ -23,7 +23,7 @@ def set_axes_equal(ax):
2323 ax: matplotlib.axes object
2424
2525
26- '''
26+ """
2727
2828 x_limits = ax .get_xlim3d ()
2929 y_limits = ax .get_ylim3d ()
@@ -38,17 +38,17 @@ def set_axes_equal(ax):
3838
3939 # The plot bounding box is a sphere in the sense of the infinity
4040 # norm, hence I call half the max range the plot radius.
41- plot_radius = 0.5 * max ([x_range , y_range , z_range ])
41+ plot_radius = 0.5 * max ([x_range , y_range , z_range ])
4242
4343 ax .set_xlim3d ([x_middle - plot_radius , x_middle + plot_radius ])
4444 ax .set_ylim3d ([y_middle - plot_radius , y_middle + plot_radius ])
4545 ax .set_zlim3d ([z_middle - plot_radius , z_middle + plot_radius ])
4646
4747
4848def scatter_3D (a , cmap = "jet" , sca_args = None , control = "color" , size = 60 ):
49-
49+
5050 # default arguments for the quiver plot. can be overwritten by quiv_args
51- if not isinstance (sca_args ,dict ):
51+ if not isinstance (sca_args , dict ):
5252 sca_args = {}
5353 scatter_args = {"alpha" : 1 }
5454 scatter_args .update (sca_args )
@@ -59,21 +59,30 @@ def scatter_3D(a, cmap="jet", sca_args=None, control="color", size=60):
5959 z = z .flatten ()
6060
6161 fig = plt .figure ()
62- ax = fig .gca (projection = '3d' , rasterized = True )
62+ ax = fig .gca (projection = "3d" , rasterized = True )
6363
6464 if control == "color" :
6565 # make cmap
6666 cbound = [np .nanmin (a ), np .nanmax (a )]
6767 # create normalized color map for arrows
68- norm = matplotlib .colors .Normalize (vmin = cbound [0 ], vmax = cbound [1 ]) # 10 ) #cbound[1] ) #)
68+ norm = matplotlib .colors .Normalize (
69+ vmin = cbound [0 ], vmax = cbound [1 ]
70+ ) # 10 ) #cbound[1] ) #)
6971 sm = matplotlib .cm .ScalarMappable (cmap = cmap , norm = norm )
7072 sm .set_array ([])
7173 # different option
7274 cm = matplotlib .cm .get_cmap (cmap )
7375 colors = cm (norm (a )).reshape (a .shape [0 ] * a .shape [1 ] * a .shape [2 ], 4 ) #
7476 # plotting
7577 nan_filter = ~ np .isnan (a .flatten ())
76- ax .scatter (x [nan_filter ], y [nan_filter ], z [nan_filter ], c = colors [nan_filter ], s = size , ** scatter_args )
78+ ax .scatter (
79+ x [nan_filter ],
80+ y [nan_filter ],
81+ z [nan_filter ],
82+ c = colors [nan_filter ],
83+ s = size ,
84+ ** scatter_args
85+ )
7786 plt .colorbar (sm )
7887
7988 if control == "alpha" :
@@ -88,13 +97,25 @@ def scatter_3D(a, cmap="jet", sca_args=None, control="color", size=60):
8897 ax_scale = plt .axes ([0.88 , 0.1 , 0.05 , 0.7 ])
8998 # ax_scale.set_ylim((0.1,1.2))
9099 nm = 5
91- ax_scale .scatter ([0 ] * nm , np .linspace (a .min (), a .max (), nm ), s = sizes .max () * np .linspace (0 , 1 , nm ))
100+ ax_scale .scatter (
101+ [0 ] * nm ,
102+ np .linspace (a .min (), a .max (), nm ),
103+ s = sizes .max () * np .linspace (0 , 1 , nm ),
104+ )
92105 ax_scale .spines ["left" ].set_visible (False )
93106 ax_scale .spines ["right" ].set_visible (True )
94107 ax_scale .spines ["bottom" ].set_visible (False )
95108 ax_scale .spines ["top" ].set_visible (False )
96- ax_scale .tick_params (axis = "both" , which = "both" , labelbottom = False , labelleft = False , labelright = True ,
97- bottom = False , left = False , right = True )
109+ ax_scale .tick_params (
110+ axis = "both" ,
111+ which = "both" ,
112+ labelbottom = False ,
113+ labelleft = False ,
114+ labelright = True ,
115+ bottom = False ,
116+ left = False ,
117+ right = True ,
118+ )
98119
99120 ax .set_xlim (0 , a .shape [0 ])
100121 ax .set_ylim (0 , a .shape [1 ])
@@ -130,7 +151,9 @@ def plot_3D_alpha(data):
130151
131152 data_fil = data .copy ()
132153 data_fil [(data == np .inf )] = np .nanmax (data [~ (data == np .inf )])
133- data_fil = (data_fil - np .nanmin (data_fil )) / (np .nanmax (data_fil ) - np .nanmin (data_fil ))
154+ data_fil = (data_fil - np .nanmin (data_fil )) / (
155+ np .nanmax (data_fil ) - np .nanmin (data_fil )
156+ )
134157 data_fil [np .isnan (data_fil )] = 0
135158
136159 col [:, :, :, 2 ] = 1
@@ -149,16 +172,31 @@ def plot_3D_alpha(data):
149172 z [:, :, 1 ::2 ] += 0.95
150173
151174 fig = plt .figure ()
152- ax = fig .gca (projection = '3d' )
175+ ax = fig .gca (projection = "3d" )
153176 ax .voxels (x , y , z , fill , facecolors = col_exp , edgecolors = col_exp )
154177 ax .set_xlabel ("x" )
155178 ax .set_ylabel ("y" )
156179 ax .set_zlabel ("z" )
157180 plt .show ()
158181
159182
160- def quiver_3D (u , v , w , x = None , y = None , z = None , mask_filtered = None , filter_def = 0 , filter_reg = (1 , 1 , 1 ),
161- cmap = "jet" , quiv_args = None , vmin = None , vmax = None , arrow_scale = 0.15 , equal_ax = True ):
183+ def quiver_3D (
184+ u ,
185+ v ,
186+ w ,
187+ x = None ,
188+ y = None ,
189+ z = None ,
190+ mask_filtered = None ,
191+ filter_def = 0 ,
192+ filter_reg = (1 , 1 , 1 ),
193+ cmap = "jet" ,
194+ quiv_args = None ,
195+ vmin = None ,
196+ vmax = None ,
197+ arrow_scale = 0.15 ,
198+ equal_ax = True ,
199+ ):
162200 """
163201 Displaying 3D deformation fields vector arrows
164202
@@ -212,15 +250,21 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, mask_filtered=None, filter_def=0,
212250 """
213251
214252 # default arguments for the quiver plot. can be overwritten by quiv_args
215- quiver_args = {"normalize" :False , "alpha" :0.8 , "pivot" :'tail' , "linewidth" :1 , "length" :1 }
253+ quiver_args = {
254+ "normalize" : False ,
255+ "alpha" : 0.8 ,
256+ "pivot" : "tail" ,
257+ "linewidth" : 1 ,
258+ "length" : 1 ,
259+ }
216260 if isinstance (quiv_args , dict ):
217261 quiver_args .update (quiv_args )
218- # overwriting length if an arrow scale and a "length" argument in quiv_args
262+ # overwriting length if an arrow scale and a "length" argument in quiv_args
219263 # is provided at the same
220264 if arrow_scale is not None :
221265 quiver_args ["length" ] = 1
222-
223- # convert filter ot list if proveided as int
266+
267+ # convert filter ot list if proveided as int
224268 if not isinstance (filter_reg , (tuple , list )):
225269 filter_reg = [filter_reg ] * 3
226270
@@ -234,8 +278,9 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, mask_filtered=None, filter_def=0,
234278 x , y , z = np .indices (u .shape )
235279 else :
236280 raise ValueError (
237- "displacement data has wrong number of dimensions (%s). Use 1d array, list, or 3d array." % str (
238- len (u .shape )))
281+ "displacement data has wrong number of dimensions (%s). Use 1d array, list, or 3d array."
282+ % str (len (u .shape ))
283+ )
239284
240285 # conversion to array
241286 x , y , z = np .array ([x , y , z ])
@@ -246,7 +291,7 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, mask_filtered=None, filter_def=0,
246291 if isinstance (filter_reg , list ):
247292 show_only = np .zeros (u .shape ).astype (bool )
248293 # filtering out every x-th
249- show_only [::filter_reg [0 ], ::filter_reg [1 ], ::filter_reg [2 ]] = True
294+ show_only [:: filter_reg [0 ], :: filter_reg [1 ], :: filter_reg [2 ]] = True
250295 mask_filtered = np .logical_and (mask_filtered , show_only )
251296
252297 xf = x [mask_filtered ]
@@ -265,28 +310,32 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, mask_filtered=None, filter_def=0,
265310 # different option
266311 colors = matplotlib .cm .jet (norm (df )) #
267312
268- colors = [c for c , d in zip (colors , df ) if d > 0 ] + list (chain (* [[c , c ] for c , d in zip (colors , df ) if d > 0 ]))
313+ colors = [c for c , d in zip (colors , df ) if d > 0 ] + list (
314+ chain (* [[c , c ] for c , d in zip (colors , df ) if d > 0 ])
315+ )
269316 # colors in ax.quiver 3d is really fucked up/ will probably change with updates:
270317 # requires list with: first len(u) entries define the colors of the shaft, then the next len(u)*2 entries define
271318 # the color ofleft and right arrow head side in alternating order. Try for example:
272319 # colors = ["red" for i in range(len(cf))] + list(chain(*[["blue", "yellow"] for i in range(len(cf))]))
273320 # to see this effect.
274- # BUT WAIT THERS MORE: zeor length arrows are apparently filtered out in the matplolib with out filtering
321+ # BUT WAIT THERS MORE: zeor length arrows are apparently filtered out in the matplolib with out filtering
275322 # the color list appropriately so we have to do this our selfs as well
276323
277324 # scale arrows to axis dimensions:
278325 ax_dims = [(x .min (), x .max ()), (y .min (), y .max ()), (z .min (), z .max ())]
279326 if arrow_scale is not None :
280327 max_length = df .max ()
281- max_dim_length = np .max ([(d [1 ] - d [0 ] + 1 ) for d in ax_dims ] )
328+ max_dim_length = np .max ([(d [1 ] - d [0 ] + 1 ) for d in ax_dims ])
282329 scale = max_dim_length * arrow_scale / max_length
283330 else :
284331 scale = 1
285332
286333 # plotting
287334 fig = plt .figure ()
288- ax = fig .gca (projection = '3d' , rasterized = True )
289- ax .quiver (xf , yf , zf , vf * scale , uf * scale , wf * scale , colors = colors , ** quiver_args )
335+ ax = fig .gca (projection = "3d" , rasterized = True )
336+ ax .quiver (
337+ xf , yf , zf , vf * scale , uf * scale , wf * scale , colors = colors , ** quiver_args
338+ )
290339 plt .colorbar (sm )
291340
292341 ax .set_xlim (ax_dims [0 ])
@@ -302,6 +351,5 @@ def quiver_3D(u, v, w, x=None, y=None, z=None, mask_filtered=None, filter_def=0,
302351 ax .w_xaxis .set_pane_color ((0.2 , 0.2 , 0.2 , 1.0 ))
303352 ax .w_yaxis .set_pane_color ((0.2 , 0.2 , 0.2 , 1.0 ))
304353 ax .w_zaxis .set_pane_color ((0.2 , 0.2 , 0.2 , 1.0 ))
305-
306- return fig
307354
355+ return fig
0 commit comments