88
99class Cluster :
1010 """
11- Clustering 4D data
12-
11+ Class for clustering data in 4D-STEM DataCube based on
12+ similarity of neighboring diffraction patterns.
1313 """
1414
1515 def __init__ (
1616 self ,
1717 datacube ,
18- r_space_mask ,
18+ r_space_mask = None ,
1919 ):
2020 """
21- Args:
22- datacube (py4DSTEM.DataCube): 4D-STEM data
23- r_space_mask (np.ndarray): Mask in real space to apply background thresholding on the similarity array.
24-
21+ Parameters
22+ ----------
23+ datacube: DataCube
24+ 4D-STEM data
25+ r_space_mask: np.ndarray
26+ Mask in real space to apply background thresholding on the similarity array.
2527 """
2628 self .datacube = datacube
2729 self .r_space_mask = r_space_mask
2830 self .similarity = None
2931 self .similarity_raw = None
3032
31- def bg_thresholding (self , r_space_mask ,):
32- self .r_space_mask = np .asarray (r_space_mask )
33-
34- # if similarity is already computed, apply the thresholding
35- if self .similarity_raw is not None :
36- self .similarity = self ._apply_bg_mask (self .similarity_raw )
37-
3833 def _apply_bg_mask (self , similarity ):
3934 if self .r_space_mask is None :
4035 return similarity
4136 return similarity * self .r_space_mask [..., None ]
4237
43-
4438 def find_similarity (
45- self ,
46- q_space_mask = None ,
47- smooth_sigma = 0 ,
48- return_similarity = False
39+ self , q_space_mask = None , smooth_sigma = 0 , return_similarity = False
4940 ):
50-
5141 """
52- Args:
53- q_space_mask : annular boolean q_space_mask to apply on the diffraction patterns
54- smooth_sigma : sigma for Gaussian smoothing of the diffraction patterns before calculating similarity
55- return_similarity : if True, return the similarity array
42+ Find similarity to neighboring pixels
43+
44+ Parameters
45+ ----------
46+ q_space_mask : np.ndarray, optional
47+ boolean q_space_mask to apply on the diffraction patterns
48+ smooth_sigma : float, optional
49+ sigma for Gaussian smoothing of the diffraction patterns
50+ before calculating similarity
51+ return_similarity : bool, optional
52+ if True, return the similarity array
53+
54+ Returns
55+ --------
56+ similarity: np.ndarray
57+ similarity scores for each pixel
5658 """
5759 if self .r_space_mask is None :
5860 self .set_mask (r_space_mask )
59-
61+
6062 # List of neighbors to search
6163 # (-1,-1) will be equivalent to (1,1)
6264 self .dxy = np .array (
@@ -82,12 +84,12 @@ def find_similarity(
8284 range (self .datacube .shape [0 ]),
8385 range (self .datacube .shape [1 ]),
8486 ):
85- diff_ref = self .datacube [rx , ry ].copy ().astype (' float' )
87+ diff_ref = self .datacube [rx , ry ].copy ().astype (" float" )
8688 diff_ref -= diff_ref .mean ()
8789
8890 if smooth_sigma > 0 :
89- diff_ref = gaussian_filter (diff_ref ,smooth_sigma )
90-
91+ diff_ref = gaussian_filter (diff_ref , smooth_sigma )
92+
9193 if q_space_mask is not None :
9294 diff_ref = diff_ref [q_space_mask ]
9395
@@ -104,29 +106,28 @@ def find_similarity(
104106 and x_ind < self .datacube .shape [0 ]
105107 and y_ind < self .datacube .shape [1 ]
106108 ):
107- diff = self .datacube [x_ind , y_ind ].copy ().astype (' float' )
109+ diff = self .datacube [x_ind , y_ind ].copy ().astype (" float" )
108110 diff -= diff .mean ()
109111
110112 if smooth_sigma > 0 :
111- diff = gaussian_filter (diff ,smooth_sigma )
112-
113+ diff = gaussian_filter (diff , smooth_sigma )
114+
113115 if q_space_mask is not None :
114116 diff = diff [q_space_mask ]
115-
117+
116118 # image self.similarity with normalized cosine correlation
117119 self .similarity [rx , ry , ind ] = (
118120 np .sum (diff * diff_ref )
119121 / np .sqrt (np .sum (diff * diff ))
120122 / norm_diff_ref
121123 )
122124
123-
124125 self .similarity_raw = self .similarity .copy ()
125126 self .similarity = self ._apply_bg_mask (self .similarity )
126127
127128 if return_similarity :
128129 return self .similarity
129-
130+
130131 # Create a function to map cluster index to color
131132 def get_color (self , cluster_index ):
132133 colors = [
@@ -149,14 +150,19 @@ def indexing_clusters_all(
149150 threshold ,
150151 ):
151152 """
152- Args:
153- threshold : threshold for similarity to consider two pixels as part of the same cluster
153+ Index all pixels in a cluster
154+
155+ Parameters
156+ ----------
157+ threshold: float
158+ similarity score threshold to consider pixels as part
159+ of the same cluster
154160 """
155-
161+
156162 sim_averaged = np .mean (self .similarity , axis = 2 )
157163
158- # Assigning the background as 'counted'
159- sim_averaged [~ self .r_space_mask ] = - 1.0
164+ # Assigning the background as 'counted'
165+ sim_averaged [~ self .r_space_mask ] = - 1.0
160166
161167 # color the pixels with the cluster index
162168 self .cluster_map = - 1 * np .ones (
@@ -165,8 +171,8 @@ def indexing_clusters_all(
165171 self .cluster_map_rgb = np .zeros (
166172 (sim_averaged .shape [0 ], sim_averaged .shape [1 ], 4 ), dtype = np .float64
167173 )
168-
169- self .cluster_map_rgb [..., 3 ] = 1.0 # start as opaque black
174+
175+ self .cluster_map_rgb [..., 3 ] = 1.0 # start as opaque black
170176
171177 # store arrays of cluster_indices in a list
172178 self .cluster_list = []
@@ -182,13 +188,12 @@ def indexing_clusters_all(
182188 # find the pixel that has the highest self.similarity among the pixels that haven't been clustered yet
183189 # this will be the 'starting pixel' of a new cluster
184190 rx0 , ry0 = np .unravel_index (sim_averaged .argmax (), sim_averaged .shape )
185-
191+
186192 # Guarding to check if the seed is background
187193 if self .r_space_mask is not None and not self .r_space_mask [rx0 , ry0 ]:
188194 sim_averaged [rx0 , ry0 ] = - 1 # mark processed so we don't pick it again
189195 continue
190196
191-
192197 cluster_indices = np .empty ((0 , 2 ))
193198 cluster_indices = (np .append (cluster_indices , [[rx0 , ry0 ]], axis = 0 )).astype (
194199 np .int32
@@ -219,54 +224,69 @@ def indexing_clusters_all(
219224 x_ind = rx0 + self .dxy [ind , 0 ]
220225 y_ind = ry0 + self .dxy [ind , 1 ]
221226
222- if x_ind > 1 and \
223- y_ind > 1 and \
224- x_ind < self .similarity .shape [0 ] - 2 and \
225- y_ind < self .similarity .shape [1 ] - 2 :
227+ if (
228+ x_ind > 1
229+ and y_ind > 1
230+ and x_ind < self .similarity .shape [0 ] - 2
231+ and y_ind < self .similarity .shape [1 ] - 2
232+ ):
226233
227- r_ok = True if self .r_space_mask is None else bool (self .r_space_mask [x_ind , y_ind ])
234+ r_ok = (
235+ True
236+ if self .r_space_mask is None
237+ else bool (self .r_space_mask [x_ind , y_ind ])
238+ )
228239
229240 # add if the neighbor is similar, but don't add if the neighbor is already in a cluster
230- if self .similarity [rx0 , ry0 , ind ] >= threshold \
231- and self .cluster_map [x_ind , y_ind ] == - 1 and r_ok :
241+ if (
242+ self .similarity [rx0 , ry0 , ind ] >= threshold
243+ and self .cluster_map [x_ind , y_ind ] == - 1
244+ and r_ok
245+ ):
232246
233-
234247 cluster_indices = np .append (
235248 cluster_indices , [[x_ind , y_ind ]], axis = 0
236249 )
237250
238251 self .cluster_map [x_ind , y_ind ] = cluster_count_ind
239-
240252
241- # self.cluster_map[x_ind, y_ind] = cluster_count_ind+1
242253 color = self .get_color (cluster_count_ind + 1 )
243- self .cluster_map_rgb [x_ind , y_ind ] = plt . cm . colors . to_rgba (
244- color
254+ self .cluster_map_rgb [x_ind , y_ind ] = (
255+ plt . cm . colors . to_rgba ( color )
245256 )
246257
247258 # if no new pixel is checked for NN then break
248259 if counting_added_pixel == 0 :
249260 break
250261
251- # single pixel cluster
252- # if cluster_indices.shape[0] == 1:
253- # self.cluster_map_rgb[cluster_indices[0, 0], cluster_indices[0, 1]] = [
254- # 0,
255- # 0,
256- # 0,
257- # 1,
258- # ]
259-
260262 self .cluster_list .append (cluster_indices )
261263 cluster_count_ind += 1
262264
263- # return cluster_count_ind, self.cluster_list, self.cluster_map, self.cluster_map_rgb
264-
265265 def create_cluster_cube (
266266 self ,
267267 min_cluster_size ,
268268 return_cluster_datacube = False ,
269269 ):
270+ """
271+ Create a dataset of shape (N, 1, qx, qy), where N is the number of
272+ clusters, containing the diffraction patterns averaged across the pixels
273+ in each cluster
274+
275+ Parameters
276+ ----------
277+ min_cluster_size: int
278+ minimum size for a cluster to be included in the dataset
279+ return_cluster_datacube: bool
280+ if True, returns clustered dataset and list of indices
281+ of clusters
282+
283+ Returns
284+ --------
285+ cluster_cube: np.ndarray
286+ dataset with clustered diffraction patterns
287+ filtered_cluster_list: list
288+ list of indices in real space of each pixel of each cluster
289+ """
270290
271291 self .filtered_cluster_list = [
272292 arr for arr in self .cluster_list if arr .shape [0 ] >= min_cluster_size
0 commit comments