@@ -2056,9 +2056,16 @@ def _constraints(
20562056
20572057 obj_6D = copy_to_device (self .object_6D , device = "cpu" )
20582058
2059- obj_6D = gaussian_filter (
2060- obj_6D , real_space_gaussian_filter , axes = (0 , 1 , 2 )
2061- ) # axes only supported in cpu
2059+ # obj_6D = gaussian_filter(
2060+ # obj_6D, real_space_gaussian_filter, axes=(0, 1, 2)
2061+ # ) # axes only supported in cpu
2062+ for i in range (obj_6D .shape [3 ]):
2063+ for j in range (obj_6D .shape [4 ]):
2064+ for k in range (obj_6D .shape [5 ]):
2065+ obj_6D [:, :, :, i , j , k ] = gaussian_filter (
2066+ obj_6D [:, :, :, i , j , k ],
2067+ sigma = real_space_gaussian_filter
2068+ )
20622069
20632070 self ._object = copy_to_device (
20642071 obj_6D .reshape ((s [0 ], s [1 ] * s [2 ], s [3 ] * s [4 ] * s [5 ])), device = storage
@@ -2255,7 +2262,7 @@ def widget(
22552262 ** kwargs ,
22562263 ):
22572264 """ """
2258- from ipywidgets import HBox , VBox , widgets , interact , Dropdown , Label , Layout
2265+ from ipywidgets import HBox , VBox , widgets , interact , Dropdown , Label , Layout , widgets
22592266 from skimage .feature import peak_local_max
22602267 from scipy .ndimage import gaussian_filter
22612268 from py4DSTEM .visualize import return_scaled_histogram_ordering
@@ -2293,6 +2300,30 @@ def widget(
22932300 vmax = vmax ,
22942301 )
22952302
2303+ # Buttons to set view angles
2304+ button_xy = widgets .Button (description = "XY view" , layout = Layout (width = "100px" ))
2305+ button_xz = widgets .Button (description = "XZ view" , layout = Layout (width = "100px" ))
2306+ button_yz = widgets .Button (description = "YZ view" , layout = Layout (width = "100px" ))
2307+
2308+ def on_click_xy (b ):
2309+ ax2 .view_init (elev = 90 , azim = - 90 ) # Top-down
2310+ fig .canvas .draw_idle ()
2311+
2312+ def on_click_xz (b ):
2313+ ax2 .view_init (elev = 0 , azim = - 90 ) # Front-on
2314+ fig .canvas .draw_idle ()
2315+
2316+ def on_click_yz (b ):
2317+ ax2 .view_init (elev = 0 , azim = 0 ) # Side-on
2318+ fig .canvas .draw_idle ()
2319+
2320+ button_xy .on_click (on_click_xy )
2321+ button_xz .on_click (on_click_xz )
2322+ button_yz .on_click (on_click_yz )
2323+
2324+ view_buttons = widgets .HBox ([button_xy , button_xz , button_yz ])
2325+
2326+
22962327 # %matplotlib ipympl
22972328
22982329 with plt .ioff ():
@@ -2301,6 +2332,10 @@ def widget(
23012332 ax1 = fig .add_subplot (1 , 3 , 2 )
23022333 ax2 = fig .add_subplot (1 , 3 , 3 , projection = "3d" )
23032334
2335+ from mpl_toolkits .mplot3d import Axes3D
2336+ from mpl_toolkits .mplot3d import axes3d
2337+ ax2 .set_proj_type ('ortho' )
2338+
23042339 x = obj_6D .shape [0 ] // 2
23052340 y = obj_6D .shape [1 ] // 2
23062341 z = obj_6D .shape [2 ] // 2
@@ -2370,6 +2405,7 @@ def widget(
23702405 ax2 .set_xlim ([0 , obj_6D .shape [3 ]])
23712406 ax2 .set_ylim ([0 , obj_6D .shape [4 ]])
23722407 ax2 .set_zlim ([0 , obj_6D .shape [5 ]])
2408+ set_axes_equal (ax2 )
23732409
23742410 plt .tight_layout ()
23752411
@@ -2474,6 +2510,8 @@ def update_images(
24742510 ax1 .set_title ("xz" )
24752511 ax2 .set_title ("Diffraction" )
24762512
2513+ set_axes_equal (ax2 )
2514+
24772515 plt .tight_layout ()
24782516
24792517 fig .canvas .draw_idle ()
@@ -2580,16 +2618,51 @@ def update_images(
25802618 fig .canvas .layout .height = "400px"
25812619 fig .canvas .toolbar_position = "bottom"
25822620
2583- widget = widgets .VBox (
2584- [
2585- fig .canvas ,
2586- HBox ([x , y ]),
2587- HBox ([z , gaussian_filter_diffraction ]),
2588- HBox ([minimum_threshold , scale_intensities ]),
2589- HBox ([intensities_power , block_center ]),
2590- ],
2591- )
2621+ widget = widgets .VBox ([
2622+ fig .canvas ,
2623+ view_buttons ,
2624+ HBox ([x , y ]),
2625+ HBox ([z , gaussian_filter_diffraction ]),
2626+ HBox ([minimum_threshold , scale_intensities ]),
2627+ HBox ([intensities_power , block_center ]),
2628+ ])
2629+
2630+ # widget = widgets.VBox(
2631+ # [
2632+ # fig.canvas,
2633+ # HBox([x, y]),
2634+ # HBox([z, gaussian_filter_diffraction]),
2635+ # HBox([minimum_threshold, scale_intensities]),
2636+ # HBox([intensities_power, block_center]),
2637+ # ],
2638+ # )
25922639
25932640 display (widget )
25942641
25952642 return self
2643+
2644+ def set_axes_equal (ax ):
2645+ """Set 3D plot axes to equal scale (for matplotlib >= 3.3)."""
2646+ xlim = ax .get_xlim3d ()
2647+ ylim = ax .get_ylim3d ()
2648+ zlim = ax .get_zlim3d ()
2649+
2650+ # Calculate ranges and midpoints
2651+ x_range , x_middle = np .ptp (xlim ), np .mean (xlim )
2652+ y_range , y_middle = np .ptp (ylim ), np .mean (ylim )
2653+ z_range , z_middle = np .ptp (zlim ), np .mean (zlim )
2654+
2655+ # Set max range
2656+ plot_radius = 0.5 * max ([x_range , y_range , z_range ])
2657+
2658+ ax .set_xlim3d ([x_middle - plot_radius , x_middle + plot_radius ])
2659+ ax .set_ylim3d ([y_middle - plot_radius , y_middle + plot_radius ])
2660+ ax .set_zlim3d ([z_middle - plot_radius , z_middle + plot_radius ])
2661+
2662+ # Force equal aspect
2663+ try :
2664+ ax .set_box_aspect ([1 , 1 , 1 ]) # Requires matplotlib >= 3.3
2665+ except AttributeError :
2666+ pass
2667+
2668+
0 commit comments