@@ -15,22 +15,49 @@ class Cluster:
1515 def __init__ (
1616 self ,
1717 datacube ,
18+ r_space_mask ,
1819 ):
1920 """
2021 Args:
21- datacube (py4DSTEM.DataCube): 4D-STEM data
22-
22+ datacube (py4DSTEM.DataCube): 4D-STEM data
23+ r_space_mask (np.ndarray): Mask in real space to apply background thresholding on the similarity array.
2324
2425 """
25-
2626 self .datacube = datacube
27+ self .r_space_mask = r_space_mask
28+ self .similarity = None
29+ self .similarity_raw = None
30+
31+ def bg_thresholding (self , r_space_mask ,):
32+ self .r_space_mask = np .asarray (r_space_mask )
33+
34+ # if similarity is already computed, apply the thresholding
35+ if self .similarity_raw is not None :
36+ self .similarity = self ._apply_bg_mask (self .similarity_raw )
37+
38+ def _apply_bg_mask (self , similarity ):
39+ if self .r_space_mask is None :
40+ return similarity
41+ return similarity * self .r_space_mask [..., None ]
42+
2743
2844 def find_similarity (
2945 self ,
30- mask = None , # by default
46+ q_space_mask = None ,
3147 smooth_sigma = 0 ,
48+ return_similarity = False
3249 ):
33- # Which neighbors to search
50+
51+ """
52+ Args:
53+ q_space_mask : annular boolean q_space_mask to apply on the diffraction patterns
54+ smooth_sigma : sigma for Gaussian smoothing of the diffraction patterns before calculating similarity
55+ return_similarity : if True, return the similarity array
56+ """
57+ if self .r_space_mask is None :
58+ self .set_mask (r_space_mask )
59+
60+ # List of neighbors to search
3461 # (-1,-1) will be equivalent to (1,1)
3562 self .dxy = np .array (
3663 (
@@ -61,8 +88,8 @@ def find_similarity(
6188 if smooth_sigma > 0 :
6289 diff_ref = gaussian_filter (diff_ref ,smooth_sigma )
6390
64- if mask is not None :
65- diff_ref = diff_ref [mask ]
91+ if q_space_mask is not None :
92+ diff_ref = diff_ref [q_space_mask ]
6693
6794 norm_diff_ref = np .sqrt (np .sum (diff_ref * diff_ref ))
6895 # diff_ref_mean = np.mean(diff_ref)
@@ -83,18 +110,23 @@ def find_similarity(
83110 if smooth_sigma > 0 :
84111 diff = gaussian_filter (diff ,smooth_sigma )
85112
86- if mask is not None :
87- diff = diff [mask ]
113+ if q_space_mask is not None :
114+ diff = diff [q_space_mask ]
88115
89- # image self.similarity with normalized corr: cosine self.similarity?
116+ # image self.similarity with normalized cosine correlation
90117 self .similarity [rx , ry , ind ] = (
91118 np .sum (diff * diff_ref )
92119 / np .sqrt (np .sum (diff * diff ))
93120 / norm_diff_ref
94121 )
95122
96- # self.similarity[rx, ry, ind] = np.mean(np.abs(diff - diff_ref)) / diff_ref_mean
123+
124+ self .similarity_raw = self .similarity .copy ()
125+ self .similarity = self ._apply_bg_mask (self .similarity )
97126
127+ if return_similarity :
128+ return self .similarity
129+
98130 # Create a function to map cluster index to color
99131 def get_color (self , cluster_index ):
100132 colors = [
@@ -114,33 +146,27 @@ def get_color(self, cluster_index):
114146 # Find the pixel with the highest self.similarity and start the clustering from there
115147 def indexing_clusters_all (
116148 self ,
117- # mask,
118149 threshold ,
119150 ):
120-
121- # self.dxy = np.array(
122- # (
123- # (-1, -1),
124- # (-1, 0),
125- # (-1, 1),
126- # (0, -1),
127- # (1, 1),
128- # (1, 0),
129- # (1, -1),
130- # (0, 1),
131- # )
132- # )
133-
151+ """
152+ Args:
153+ threshold : threshold for similarity to consider two pixels as part of the same cluster
154+ """
155+
134156 sim_averaged = np .mean (self .similarity , axis = 2 )
135157
158+ # Assigning the background as 'counted'
159+ sim_averaged [~ self .r_space_mask ] = - 1.0
160+
136161 # color the pixels with the cluster index
137- # map_cluster = np.zeros((sim_averaged.shape[0],sim_averaged.shape[1]))
138162 self .cluster_map = - 1 * np .ones (
139163 (sim_averaged .shape [0 ], sim_averaged .shape [1 ]), dtype = np .float64
140164 )
141165 self .cluster_map_rgb = np .zeros (
142166 (sim_averaged .shape [0 ], sim_averaged .shape [1 ], 4 ), dtype = np .float64
143167 )
168+
169+ self .cluster_map_rgb [..., 3 ] = 1.0 #start as opaque black
144170
145171 # store arrays of cluster_indices in a list
146172 self .cluster_list = []
@@ -156,14 +182,18 @@ def indexing_clusters_all(
156182 # finding the pixel that has the highest self.similarity among the pixel that hasn't been clustered yet
157183 # this will be the 'starting pixel' of a new cluster
158184 rx0 , ry0 = np .unravel_index (sim_averaged .argmax (), sim_averaged .shape )
159- # print(rx0, ry0)
185+
186+ # Guarding to check if the seed is background
187+ if self .r_space_mask is not None and not self .r_space_mask [rx0 , ry0 ]:
188+ sim_averaged [rx0 , ry0 ] = - 1 # mark processed so we don't pick it again
189+ continue
190+
160191
161192 cluster_indices = np .empty ((0 , 2 ))
162193 cluster_indices = (np .append (cluster_indices , [[rx0 , ry0 ]], axis = 0 )).astype (
163194 np .int32
164195 )
165196
166- # map_cluster[rx0, ry0] = cluster_count_ind+1
167197 self .cluster_map [rx0 , ry0 ] = cluster_count_ind
168198
169199 color = self .get_color (cluster_count_ind + 1 )
@@ -182,7 +212,7 @@ def indexing_clusters_all(
182212 # counter to check if pixel in the cluster are checked for NN
183213 counting_added_pixel += 1
184214
185- # set to -1 as its NN will be checked
215+ # set to -1 since now its NN will be checked
186216 sim_averaged [rx0 , ry0 ] = - 1
187217
188218 for ind in range (self .dxy .shape [0 ]):
@@ -194,12 +224,13 @@ def indexing_clusters_all(
194224 x_ind < self .similarity .shape [0 ] - 2 and \
195225 y_ind < self .similarity .shape [1 ] - 2 :
196226
227+ r_ok = True if self .r_space_mask is None else bool (self .r_space_mask [x_ind , y_ind ])
228+
197229 # add if the neighbor is similar, but don't add if the neighbor is already in a cluster
198230 if self .similarity [rx0 , ry0 , ind ] >= threshold \
199- and self .cluster_map [x_ind , y_ind ] == - 1 :
231+ and self .cluster_map [x_ind , y_ind ] == - 1 and r_ok :
200232
201- # print(cluster_indices)
202- # print([[x_ind, y_ind]])
233+
203234 cluster_indices = np .append (
204235 cluster_indices , [[x_ind , y_ind ]], axis = 0
205236 )
@@ -217,9 +248,9 @@ def indexing_clusters_all(
217248 if counting_added_pixel == 0 :
218249 break
219250
220- # # single pixel cluster
251+ # single pixel cluster
221252 # if cluster_indices.shape[0] == 1:
222- # self.cluster_map [cluster_indices[0, 0], cluster_indices[0, 1]] = [
253+ # self.cluster_map_rgb [cluster_indices[0, 0], cluster_indices[0, 1]] = [
223254 # 0,
224255 # 0,
225256 # 0,
@@ -229,7 +260,7 @@ def indexing_clusters_all(
229260 self .cluster_list .append (cluster_indices )
230261 cluster_count_ind += 1
231262
232- # return cluster_count_ind, self.cluster_list, map_cluster, sim_averaged
263+ # return cluster_count_ind, self.cluster_list, self.cluster_map, self.cluster_map_rgb
233264
234265 def create_cluster_cube (
235266 self ,
0 commit comments