88
99class Cluster :
1010 """
11- Clustering 4D data
11+ Clustering 4D data
1212
1313 """
1414
@@ -20,185 +20,223 @@ def __init__(
2020 Args:
2121 datacube (py4DSTEM.DataCube): 4D-STEM data
2222
23-
23+
2424 """
2525
2626 self .datacube = datacube
2727
28-
2928 def find_similarity (
3029 self ,
31- mask = None , # by default
30+ mask = None , # by default
3231 ):
3332 # Which neighbors to search
3433 # (-1,-1) will be equivalent to (1,1)
35- self .dxy = np .array ((
36- (- 1 ,- 1 ),
37- (- 1 ,0 ),
38- (- 1 ,1 ),
39- (0 ,- 1 ),
40- (1 ,1 ),
41- (1 ,0 ),
42- (1 ,- 1 ),
43- (0 ,1 ),
44- ))
45-
46- # initialize the self.similarity array
47- self .similarity = - 1 * np .ones ((self .datacube .shape [0 ],self .datacube .shape [1 ], self .dxy .shape [0 ]))
48-
34+ self .dxy = np .array (
35+ (
36+ (- 1 , - 1 ),
37+ (- 1 , 0 ),
38+ (- 1 , 1 ),
39+ (0 , - 1 ),
40+ (1 , 1 ),
41+ (1 , 0 ),
42+ (1 , - 1 ),
43+ (0 , 1 ),
44+ )
45+ )
46+
47+ # initialize the self.similarity array
48+ self .similarity = - 1 * np .ones (
49+ (self .datacube .shape [0 ], self .datacube .shape [1 ], self .dxy .shape [0 ])
50+ )
51+
4952 # Loop over probe positions
5053 for rx , ry in tqdmnd (
51- range (self .datacube .shape [0 ]),
54+ range (self .datacube .shape [0 ]),
5255 range (self .datacube .shape [1 ]),
5356 ):
5457 if mask is None :
55- diff_ref = self .datacube [rx ,ry ]
58+ diff_ref = self .datacube [rx , ry ]
5659 else :
5760 diff_ref = self .datacube [rx , ry ][mask ]
58-
61+
5962 # loop over neighbors
6063 for ind in range (self .dxy .shape [0 ]):
61- x_ind = rx + self .dxy [ind ,0 ]
62- y_ind = ry + self .dxy [ind ,1 ]
63- if x_ind >= 0 and \
64- y_ind >= 0 and \
65- x_ind < self .datacube .shape [0 ] and \
66- y_ind < self .datacube .shape [1 ]:
67-
64+ x_ind = rx + self .dxy [ind , 0 ]
65+ y_ind = ry + self .dxy [ind , 1 ]
66+ if (
67+ x_ind >= 0
68+ and y_ind >= 0
69+ and x_ind < self .datacube .shape [0 ]
70+ and y_ind < self .datacube .shape [1 ]
71+ ):
72+
6873 if mask is None :
69- diff = self .datacube [x_ind ,y_ind ]
74+ diff = self .datacube [x_ind , y_ind ]
7075 else :
71- diff = self .datacube [x_ind ,y_ind ][mask ]
76+ diff = self .datacube [x_ind , y_ind ][mask ]
7277
7378 # # image self.similarity with mean abs difference
7479 # self.similarity[rx,ry,ind] = np.mean(
7580 # np.abs(
7681 # diff - diff_ref
7782 # )
7883 # )
79-
84+
8085 # image self.similarity with normalized corr: cosine self.similarity?
81- self .similarity [rx ,ry ,ind ] = np .sum (diff * diff_ref ) \
82- / np .sqrt (np .sum (diff * diff )) \
83- / np .sqrt (np .sum (diff_ref * diff_ref ))
84-
85-
86+ self .similarity [rx , ry , ind ] = (
87+ np .sum (diff * diff_ref )
88+ / np .sqrt (np .sum (diff * diff ))
89+ / np .sqrt (np .sum (diff_ref * diff_ref ))
90+ )
91+
8692 # Create a function to map cluster index to color
87- def get_color (
88- self ,
89- cluster_index ):
90- colors = ['slategray' ,'lightcoral' , 'gold' , 'darkorange' ,
91- 'yellowgreen' , 'lightseagreen' , 'cornflowerblue' , 'royalblue' , 'lightsteelblue' , 'darkseagreen' ]
93+ def get_color (self , cluster_index ):
94+ colors = [
95+ "slategray" ,
96+ "lightcoral" ,
97+ "gold" ,
98+ "darkorange" ,
99+ "yellowgreen" ,
100+ "lightseagreen" ,
101+ "cornflowerblue" ,
102+ "royalblue" ,
103+ "lightsteelblue" ,
104+ "darkseagreen" ,
105+ ]
92106 return colors [(cluster_index - 1 ) % len (colors )]
93-
94- # Find the pixel with the highest self.similarity and start the clustering from there
107+
108+ # Find the pixel with the highest self.similarity and start the clustering from there
95109 def indexing_clusters_all (
96110 self ,
97111 mask ,
98112 threshold ,
99-
100113 ):
101-
102- self .dxy = np .array ((
103- (- 1 ,- 1 ),
104- (- 1 ,0 ),
105- (- 1 ,1 ),
106- (0 ,- 1 ),
107- (1 ,1 ),
108- (1 ,0 ),
109- (1 ,- 1 ),
110- (0 ,1 ),
111- ))
112-
113- sim_averaged = np .mean (self .similarity , axis = 2 )
114-
115- #color the pixels with the cluster index
116- #map_cluster = np.zeros((sim_averaged.shape[0],sim_averaged.shape[1]))
117- self .cluster_map = np .zeros ((sim_averaged .shape [0 ],sim_averaged .shape [1 ],4 ),dtype = np .float64 )
118-
119- #store arrays of cluster_indices in a list
114+
115+ self .dxy = np .array (
116+ (
117+ (- 1 , - 1 ),
118+ (- 1 , 0 ),
119+ (- 1 , 1 ),
120+ (0 , - 1 ),
121+ (1 , 1 ),
122+ (1 , 0 ),
123+ (1 , - 1 ),
124+ (0 , 1 ),
125+ )
126+ )
127+
128+ sim_averaged = np .mean (self .similarity , axis = 2 )
129+
130+ # color the pixels with the cluster index
131+ # map_cluster = np.zeros((sim_averaged.shape[0],sim_averaged.shape[1]))
132+ self .cluster_map = np .zeros (
133+ (sim_averaged .shape [0 ], sim_averaged .shape [1 ], 4 ), dtype = np .float64
134+ )
135+
136+ # store arrays of cluster_indices in a list
120137 self .cluster_list = []
121138
122139 # incides of pixel in a cluster
123- cluster_indices = np .empty ((0 ,2 ))
124-
140+ cluster_indices = np .empty ((0 , 2 ))
141+
125142 # Loop over pixels until no new pixel is found (sim_averaged is set to -1 if it is alreaddy serached for NN)
126143 cluster_count_ind = 0
127-
144+
128145 while np .any (sim_averaged != - 1 ):
129146
130- #finding the pixel that has the highest self.similarity among the pixel that hasn't been clustered yet
131- #this will be the 'starting pixel' of a new cluster
147+ # finding the pixel that has the highest self.similarity among the pixel that hasn't been clustered yet
148+ # this will be the 'starting pixel' of a new cluster
132149 rx0 , ry0 = np .unravel_index (sim_averaged .argmax (), sim_averaged .shape )
133- #print(rx0, ry0)
134-
135- cluster_indices = np .empty ((0 ,2 ))
136- cluster_indices = (np .append (cluster_indices , [[rx0 , ry0 ]], axis = 0 )).astype (np .int32 )
137-
138- #map_cluster[rx0, ry0] = cluster_count_ind+1
139- color = self .get_color (cluster_count_ind + 1 )
140- self .cluster_map [rx0 ,ry0 ] = plt .cm .colors .to_rgba (color )
141-
142- #Clustering: one cluster per while loop(until it breaks)
143- #Marching algorithm: find a new position and search the nearest neighbor
144-
145- while True :
150+ # print(rx0, ry0)
151+
152+ cluster_indices = np .empty ((0 , 2 ))
153+ cluster_indices = (np .append (cluster_indices , [[rx0 , ry0 ]], axis = 0 )).astype (
154+ np .int32
155+ )
156+
157+ # map_cluster[rx0, ry0] = cluster_count_ind+1
158+ color = self .get_color (cluster_count_ind + 1 )
159+ self .cluster_map [rx0 , ry0 ] = plt .cm .colors .to_rgba (color )
160+
161+ # Clustering: one cluster per while loop(until it breaks)
162+ # Marching algorithm: find a new position and search the nearest neighbor
163+
164+ while True :
146165 counting_added_pixel = 0
147166
148167 for rx0 , ry0 in cluster_indices :
149-
168+
150169 if sim_averaged [rx0 , ry0 ] != - 1 :
151-
152- #counter to check if pixel in the cluster are checked for NN
170+
171+ # counter to check if pixel in the cluster are checked for NN
153172 counting_added_pixel += 1
154-
173+
155174 # set to -1 as its NN will be checked
156175 sim_averaged [rx0 , ry0 ] = - 1
157-
176+
158177 for ind in range (self .dxy .shape [0 ]):
159- x_ind = rx0 + self .dxy [ind ,0 ]
160- y_ind = ry0 + self .dxy [ind ,1 ]
161-
178+ x_ind = rx0 + self .dxy [ind , 0 ]
179+ y_ind = ry0 + self .dxy [ind , 1 ]
180+
162181 # add if the neighbor is similar, but don't add if the neighbor is already in a cluster
163- if self .similarity [rx0 ,ry0 ,ind ] > threshold and \
164- np .array_equal (self .cluster_map [x_ind ,y_ind ], [0 ,0 ,0 ,0 ]):
165-
166- cluster_indices = np .append (cluster_indices , [[x_ind ,y_ind ]], axis = 0 )
167- #self.cluster_map[x_ind, y_ind] = cluster_count_ind+1
168- color = self .get_color (cluster_count_ind + 1 )
169- self .cluster_map [x_ind ,y_ind ] = plt .cm .colors .to_rgba (color )
170-
171- #if no new pixel is checked for NN then break
182+ if self .similarity [
183+ rx0 , ry0 , ind
184+ ] > threshold and np .array_equal (
185+ self .cluster_map [x_ind , y_ind ], [0 , 0 , 0 , 0 ]
186+ ):
187+
188+ cluster_indices = np .append (
189+ cluster_indices , [[x_ind , y_ind ]], axis = 0
190+ )
191+ # self.cluster_map[x_ind, y_ind] = cluster_count_ind+1
192+ color = self .get_color (cluster_count_ind + 1 )
193+ self .cluster_map [x_ind , y_ind ] = plt .cm .colors .to_rgba (
194+ color
195+ )
196+
197+ # if no new pixel is checked for NN then break
172198 if counting_added_pixel == 0 :
173- break
174-
175- #single pixel cluster
199+ break
200+
201+ # single pixel cluster
176202 if cluster_indices .shape [0 ] == 1 :
177- self .cluster_map [cluster_indices [0 ,0 ],cluster_indices [0 ,1 ]] = [0 ,0 ,0 ,1 ]
178-
179- self .cluster_list .append (cluster_indices )
203+ self .cluster_map [cluster_indices [0 , 0 ], cluster_indices [0 , 1 ]] = [
204+ 0 ,
205+ 0 ,
206+ 0 ,
207+ 1 ,
208+ ]
209+
210+ self .cluster_list .append (cluster_indices )
180211 cluster_count_ind += 1
181-
182- #return cluster_count_ind, self.cluster_list, map_cluster, sim_averaged
212+
213+ # return cluster_count_ind, self.cluster_list, map_cluster, sim_averaged
183214
184215 def create_cluster_cube (
185- self ,
186-
187- min_cluster_size ,
188- return_cluster_datacube = False ,
189- ):
190-
191- self .filtered_cluster_list = [arr for arr in self .cluster_list if arr .shape [0 ] >= min_cluster_size ]
192-
193- # datacube [i,j,k,l] where i is the index of the cluster, and j is a place holder, and k,l are the average diffraction pattern of the
194- self .cluster_cube = np .empty ([len (self .filtered_cluster_list ),1 ,self .datacube .shape [2 ], self .datacube .shape [3 ]])
195-
196- for i in tqdmnd (
197- range (len (self .filtered_cluster_list ))
198- ):
199- self .cluster_cube [i ,0 ] = self .datacube [np .array (self .filtered_cluster_list [i ])[:, 0 ], np .array (self .filtered_cluster_list [i ])[:, 1 ]].mean (axis = 0 )
216+ self ,
217+ min_cluster_size ,
218+ return_cluster_datacube = False ,
219+ ):
200220
201- if return_cluster_datacube :
221+ self .filtered_cluster_list = [
222+ arr for arr in self .cluster_list if arr .shape [0 ] >= min_cluster_size
223+ ]
224+
225+ # datacube [i,j,k,l] where i is the index of the cluster, and j is a place holder, and k,l are the average diffraction pattern of the
226+ self .cluster_cube = np .empty (
227+ [
228+ len (self .filtered_cluster_list ),
229+ 1 ,
230+ self .datacube .shape [2 ],
231+ self .datacube .shape [3 ],
232+ ]
233+ )
234+
235+ for i in tqdmnd (range (len (self .filtered_cluster_list ))):
236+ self .cluster_cube [i , 0 ] = self .datacube [
237+ np .array (self .filtered_cluster_list [i ])[:, 0 ],
238+ np .array (self .filtered_cluster_list [i ])[:, 1 ],
239+ ].mean (axis = 0 )
240+
241+ if return_cluster_datacube :
202242 return self .cluster_cube , self .filtered_cluster_list
203-
204-
0 commit comments