Skip to content

Commit 06e5ef9

Browse files
committed
black formatting
1 parent 31401f9 commit 06e5ef9

File tree

2 files changed

+163
-125
lines changed

2 files changed

+163
-125
lines changed

py4DSTEM/process/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
from py4DSTEM.process import diffraction
99
from py4DSTEM.process import wholepatternfit
1010

11-
from py4DSTEM.process.utils import Cluster
11+
from py4DSTEM.process.utils import Cluster

py4DSTEM/process/utils/cluster.py

Lines changed: 162 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class Cluster:
1010
"""
11-
Clustering 4D data
11+
Clustering 4D data
1212
1313
"""
1414

@@ -20,185 +20,223 @@ def __init__(
2020
Args:
2121
datacube (py4DSTEM.DataCube): 4D-STEM data
2222
23-
23+
2424
"""
2525

2626
self.datacube = datacube
2727

28-
2928
def find_similarity(
3029
self,
31-
mask = None, # by default
30+
mask=None, # by default
3231
):
3332
# Which neighbors to search
3433
# (-1,-1) will be equivalent to (1,1)
35-
self.dxy = np.array((
36-
(-1,-1),
37-
(-1,0),
38-
(-1,1),
39-
(0,-1),
40-
(1,1),
41-
(1,0),
42-
(1,-1),
43-
(0,1),
44-
))
45-
46-
# initialize the self.similarity array
47-
self.similarity = -1*np.ones((self.datacube.shape[0],self.datacube.shape[1], self.dxy.shape[0]))
48-
34+
self.dxy = np.array(
35+
(
36+
(-1, -1),
37+
(-1, 0),
38+
(-1, 1),
39+
(0, -1),
40+
(1, 1),
41+
(1, 0),
42+
(1, -1),
43+
(0, 1),
44+
)
45+
)
46+
47+
# initialize the self.similarity array
48+
self.similarity = -1 * np.ones(
49+
(self.datacube.shape[0], self.datacube.shape[1], self.dxy.shape[0])
50+
)
51+
4952
# Loop over probe positions
5053
for rx, ry in tqdmnd(
51-
range(self.datacube.shape[0]),
54+
range(self.datacube.shape[0]),
5255
range(self.datacube.shape[1]),
5356
):
5457
if mask is None:
55-
diff_ref = self.datacube[rx,ry]
58+
diff_ref = self.datacube[rx, ry]
5659
else:
5760
diff_ref = self.datacube[rx, ry][mask]
58-
61+
5962
# loop over neighbors
6063
for ind in range(self.dxy.shape[0]):
61-
x_ind = rx+self.dxy[ind,0]
62-
y_ind = ry+self.dxy[ind,1]
63-
if x_ind >= 0 and \
64-
y_ind >= 0 and \
65-
x_ind < self.datacube.shape[0] and \
66-
y_ind < self.datacube.shape[1]:
67-
64+
x_ind = rx + self.dxy[ind, 0]
65+
y_ind = ry + self.dxy[ind, 1]
66+
if (
67+
x_ind >= 0
68+
and y_ind >= 0
69+
and x_ind < self.datacube.shape[0]
70+
and y_ind < self.datacube.shape[1]
71+
):
72+
6873
if mask is None:
69-
diff = self.datacube[x_ind,y_ind]
74+
diff = self.datacube[x_ind, y_ind]
7075
else:
71-
diff = self.datacube[x_ind,y_ind][mask]
76+
diff = self.datacube[x_ind, y_ind][mask]
7277

7378
# # image self.similarity with mean abs difference
7479
# self.similarity[rx,ry,ind] = np.mean(
7580
# np.abs(
7681
# diff - diff_ref
7782
# )
7883
# )
79-
84+
8085
# image self.similarity with normalized corr: cosine self.similarity?
81-
self.similarity[rx,ry,ind] = np.sum(diff*diff_ref) \
82-
/ np.sqrt(np.sum(diff*diff)) \
83-
/ np.sqrt(np.sum(diff_ref*diff_ref))
84-
85-
86+
self.similarity[rx, ry, ind] = (
87+
np.sum(diff * diff_ref)
88+
/ np.sqrt(np.sum(diff * diff))
89+
/ np.sqrt(np.sum(diff_ref * diff_ref))
90+
)
91+
8692
# Create a function to map cluster index to color
87-
def get_color(
88-
self,
89-
cluster_index):
90-
colors = ['slategray','lightcoral', 'gold', 'darkorange',
91-
'yellowgreen', 'lightseagreen', 'cornflowerblue', 'royalblue', 'lightsteelblue', 'darkseagreen']
93+
def get_color(self, cluster_index):
94+
colors = [
95+
"slategray",
96+
"lightcoral",
97+
"gold",
98+
"darkorange",
99+
"yellowgreen",
100+
"lightseagreen",
101+
"cornflowerblue",
102+
"royalblue",
103+
"lightsteelblue",
104+
"darkseagreen",
105+
]
92106
return colors[(cluster_index - 1) % len(colors)]
93-
94-
# Find the pixel with the highest self.similarity and start the clustering from there
107+
108+
# Find the pixel with the highest self.similarity and start the clustering from there
95109
def indexing_clusters_all(
96110
self,
97111
mask,
98112
threshold,
99-
100113
):
101-
102-
self.dxy = np.array((
103-
(-1,-1),
104-
(-1,0),
105-
(-1,1),
106-
(0,-1),
107-
(1,1),
108-
(1,0),
109-
(1,-1),
110-
(0,1),
111-
))
112-
113-
sim_averaged = np.mean(self.similarity, axis = 2)
114-
115-
#color the pixels with the cluster index
116-
#map_cluster = np.zeros((sim_averaged.shape[0],sim_averaged.shape[1]))
117-
self.cluster_map = np.zeros((sim_averaged.shape[0],sim_averaged.shape[1],4),dtype=np.float64)
118-
119-
#store arrays of cluster_indices in a list
114+
115+
self.dxy = np.array(
116+
(
117+
(-1, -1),
118+
(-1, 0),
119+
(-1, 1),
120+
(0, -1),
121+
(1, 1),
122+
(1, 0),
123+
(1, -1),
124+
(0, 1),
125+
)
126+
)
127+
128+
sim_averaged = np.mean(self.similarity, axis=2)
129+
130+
# color the pixels with the cluster index
131+
# map_cluster = np.zeros((sim_averaged.shape[0],sim_averaged.shape[1]))
132+
self.cluster_map = np.zeros(
133+
(sim_averaged.shape[0], sim_averaged.shape[1], 4), dtype=np.float64
134+
)
135+
136+
# store arrays of cluster_indices in a list
120137
self.cluster_list = []
121138

122139
# incides of pixel in a cluster
123-
cluster_indices = np.empty((0,2))
124-
140+
cluster_indices = np.empty((0, 2))
141+
125142
# Loop over pixels until no new pixel is found (sim_averaged is set to -1 if it is alreaddy serached for NN)
126143
cluster_count_ind = 0
127-
144+
128145
while np.any(sim_averaged != -1):
129146

130-
#finding the pixel that has the highest self.similarity among the pixel that hasn't been clustered yet
131-
#this will be the 'starting pixel' of a new cluster
147+
# finding the pixel that has the highest self.similarity among the pixel that hasn't been clustered yet
148+
# this will be the 'starting pixel' of a new cluster
132149
rx0, ry0 = np.unravel_index(sim_averaged.argmax(), sim_averaged.shape)
133-
#print(rx0, ry0)
134-
135-
cluster_indices = np.empty((0,2))
136-
cluster_indices = (np.append(cluster_indices, [[rx0, ry0]], axis=0)).astype(np.int32)
137-
138-
#map_cluster[rx0, ry0] = cluster_count_ind+1
139-
color = self.get_color(cluster_count_ind+1)
140-
self.cluster_map[rx0,ry0] = plt.cm.colors.to_rgba(color)
141-
142-
#Clustering: one cluster per while loop(until it breaks)
143-
#Marching algorithm: find a new position and search the nearest neighbor
144-
145-
while True:
150+
# print(rx0, ry0)
151+
152+
cluster_indices = np.empty((0, 2))
153+
cluster_indices = (np.append(cluster_indices, [[rx0, ry0]], axis=0)).astype(
154+
np.int32
155+
)
156+
157+
# map_cluster[rx0, ry0] = cluster_count_ind+1
158+
color = self.get_color(cluster_count_ind + 1)
159+
self.cluster_map[rx0, ry0] = plt.cm.colors.to_rgba(color)
160+
161+
# Clustering: one cluster per while loop(until it breaks)
162+
# Marching algorithm: find a new position and search the nearest neighbor
163+
164+
while True:
146165
counting_added_pixel = 0
147166

148167
for rx0, ry0 in cluster_indices:
149-
168+
150169
if sim_averaged[rx0, ry0] != -1:
151-
152-
#counter to check if pixel in the cluster are checked for NN
170+
171+
# counter to check if pixel in the cluster are checked for NN
153172
counting_added_pixel += 1
154-
173+
155174
# set to -1 as its NN will be checked
156175
sim_averaged[rx0, ry0] = -1
157-
176+
158177
for ind in range(self.dxy.shape[0]):
159-
x_ind = rx0+self.dxy[ind,0]
160-
y_ind = ry0+self.dxy[ind,1]
161-
178+
x_ind = rx0 + self.dxy[ind, 0]
179+
y_ind = ry0 + self.dxy[ind, 1]
180+
162181
# add if the neighbor is similar, but don't add if the neighbor is already in a cluster
163-
if self.similarity[rx0,ry0,ind] > threshold and \
164-
np.array_equal(self.cluster_map[x_ind,y_ind], [0,0,0,0]):
165-
166-
cluster_indices = np.append(cluster_indices, [[x_ind,y_ind]], axis=0)
167-
#self.cluster_map[x_ind, y_ind] = cluster_count_ind+1
168-
color = self.get_color(cluster_count_ind+1)
169-
self.cluster_map[x_ind,y_ind] = plt.cm.colors.to_rgba(color)
170-
171-
#if no new pixel is checked for NN then break
182+
if self.similarity[
183+
rx0, ry0, ind
184+
] > threshold and np.array_equal(
185+
self.cluster_map[x_ind, y_ind], [0, 0, 0, 0]
186+
):
187+
188+
cluster_indices = np.append(
189+
cluster_indices, [[x_ind, y_ind]], axis=0
190+
)
191+
# self.cluster_map[x_ind, y_ind] = cluster_count_ind+1
192+
color = self.get_color(cluster_count_ind + 1)
193+
self.cluster_map[x_ind, y_ind] = plt.cm.colors.to_rgba(
194+
color
195+
)
196+
197+
# if no new pixel is checked for NN then break
172198
if counting_added_pixel == 0:
173-
break
174-
175-
#single pixel cluster
199+
break
200+
201+
# single pixel cluster
176202
if cluster_indices.shape[0] == 1:
177-
self.cluster_map[cluster_indices[0,0],cluster_indices[0,1]] = [0,0,0,1]
178-
179-
self.cluster_list.append(cluster_indices)
203+
self.cluster_map[cluster_indices[0, 0], cluster_indices[0, 1]] = [
204+
0,
205+
0,
206+
0,
207+
1,
208+
]
209+
210+
self.cluster_list.append(cluster_indices)
180211
cluster_count_ind += 1
181-
182-
#return cluster_count_ind, self.cluster_list, map_cluster, sim_averaged
212+
213+
# return cluster_count_ind, self.cluster_list, map_cluster, sim_averaged
183214

184215
def create_cluster_cube(
185-
self,
186-
187-
min_cluster_size,
188-
return_cluster_datacube=False,
189-
):
190-
191-
self.filtered_cluster_list = [arr for arr in self.cluster_list if arr.shape[0] >= min_cluster_size]
192-
193-
# datacube [i,j,k,l] where i is the index of the cluster, and j is a place holder, and k,l are the average diffraction pattern of the
194-
self.cluster_cube = np.empty([len(self.filtered_cluster_list),1,self.datacube.shape[2], self.datacube.shape[3]])
195-
196-
for i in tqdmnd(
197-
range(len(self.filtered_cluster_list))
198-
):
199-
self.cluster_cube[i,0] = self.datacube[np.array(self.filtered_cluster_list[i])[:, 0], np.array(self.filtered_cluster_list[i])[:, 1]].mean(axis=0)
216+
self,
217+
min_cluster_size,
218+
return_cluster_datacube=False,
219+
):
200220

201-
if return_cluster_datacube:
221+
self.filtered_cluster_list = [
222+
arr for arr in self.cluster_list if arr.shape[0] >= min_cluster_size
223+
]
224+
225+
# datacube [i,j,k,l] where i is the index of the cluster, and j is a place holder, and k,l are the average diffraction pattern of the
226+
self.cluster_cube = np.empty(
227+
[
228+
len(self.filtered_cluster_list),
229+
1,
230+
self.datacube.shape[2],
231+
self.datacube.shape[3],
232+
]
233+
)
234+
235+
for i in tqdmnd(range(len(self.filtered_cluster_list))):
236+
self.cluster_cube[i, 0] = self.datacube[
237+
np.array(self.filtered_cluster_list[i])[:, 0],
238+
np.array(self.filtered_cluster_list[i])[:, 1],
239+
].mean(axis=0)
240+
241+
if return_cluster_datacube:
202242
return self.cluster_cube, self.filtered_cluster_list
203-
204-

0 commit comments

Comments
 (0)