Skip to content

Commit c21eca4

Browse files
committed
doc strings update and delete extra code
1 parent 2ca5901 commit c21eca4

File tree

1 file changed

+86
-66
lines changed

1 file changed

+86
-66
lines changed

py4DSTEM/process/utils/cluster.py

Lines changed: 86 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,57 @@
88

99
class Cluster:
1010
"""
11-
Clustering 4D data
12-
11+
Class for clustering data in 4D-STEM DataCube based on
12+
similarity of neighboring diffraction patterns.
1313
"""
1414

1515
def __init__(
1616
self,
1717
datacube,
18-
r_space_mask,
18+
r_space_mask=None,
1919
):
2020
"""
21-
Args:
22-
datacube (py4DSTEM.DataCube): 4D-STEM data
23-
r_space_mask (np.ndarray): Mask in real space to apply background thresholding on the similarity array.
24-
21+
Parameters
22+
----------
23+
datacube: DataCube
24+
4D-STEM data
25+
r_space_mask: np.ndarray
26+
Mask in real space to apply background thresholding on the similarity array.
2527
"""
2628
self.datacube = datacube
2729
self.r_space_mask = r_space_mask
2830
self.similarity = None
2931
self.similarity_raw = None
3032

31-
def bg_thresholding(self, r_space_mask,):
32-
self.r_space_mask = np.asarray(r_space_mask)
33-
34-
# if similarity is already computed, apply the thresholding
35-
if self.similarity_raw is not None:
36-
self.similarity = self._apply_bg_mask(self.similarity_raw)
37-
3833
def _apply_bg_mask(self, similarity):
3934
if self.r_space_mask is None:
4035
return similarity
4136
return similarity * self.r_space_mask[..., None]
4237

43-
4438
def find_similarity(
45-
self,
46-
q_space_mask=None,
47-
smooth_sigma = 0,
48-
return_similarity = False
39+
self, q_space_mask=None, smooth_sigma=0, return_similarity=False
4940
):
50-
5141
"""
52-
Args:
53-
q_space_mask : annular boolean q_space_mask to apply on the diffraction patterns
54-
smooth_sigma : sigma for Gaussian smoothing of the diffraction patterns before calculating similarity
55-
return_similarity : if True, return the similarity array
42+
Find similarity to neighboring pixels
43+
44+
Parameters
45+
----------
46+
q_space_mask : np.ndarray, optional
47+
boolean q_space_mask to apply on the diffraction patterns
48+
smooth_sigma : float, optional
49+
sigma for Gaussian smoothing of the diffraction patterns
50+
before calculating similarity
51+
return_similarity : bool, optional
52+
if True, return the similarity array
53+
54+
Returns
55+
-------
56+
similarity: np.ndarray
57+
similarity scores for each pixel
5658
"""
5759
if self.r_space_mask is None:
5860
self.set_mask(r_space_mask)
59-
61+
6062
# List of neighbors to search
6163
# (-1,-1) will be equivalent to (1,1)
6264
self.dxy = np.array(
@@ -82,12 +84,12 @@ def find_similarity(
8284
range(self.datacube.shape[0]),
8385
range(self.datacube.shape[1]),
8486
):
85-
diff_ref = self.datacube[rx, ry].copy().astype('float')
87+
diff_ref = self.datacube[rx, ry].copy().astype("float")
8688
diff_ref -= diff_ref.mean()
8789

8890
if smooth_sigma > 0:
89-
diff_ref = gaussian_filter(diff_ref,smooth_sigma)
90-
91+
diff_ref = gaussian_filter(diff_ref, smooth_sigma)
92+
9193
if q_space_mask is not None:
9294
diff_ref = diff_ref[q_space_mask]
9395

@@ -104,29 +106,28 @@ def find_similarity(
104106
and x_ind < self.datacube.shape[0]
105107
and y_ind < self.datacube.shape[1]
106108
):
107-
diff = self.datacube[x_ind, y_ind].copy().astype('float')
109+
diff = self.datacube[x_ind, y_ind].copy().astype("float")
108110
diff -= diff.mean()
109111

110112
if smooth_sigma > 0:
111-
diff = gaussian_filter(diff,smooth_sigma)
112-
113+
diff = gaussian_filter(diff, smooth_sigma)
114+
113115
if q_space_mask is not None:
114116
diff = diff[q_space_mask]
115-
117+
116118
# image self.similarity with normalized cosine correlation
117119
self.similarity[rx, ry, ind] = (
118120
np.sum(diff * diff_ref)
119121
/ np.sqrt(np.sum(diff * diff))
120122
/ norm_diff_ref
121123
)
122124

123-
124125
self.similarity_raw = self.similarity.copy()
125126
self.similarity = self._apply_bg_mask(self.similarity)
126127

127128
if return_similarity:
128129
return self.similarity
129-
130+
130131
# Create a function to map cluster index to color
131132
def get_color(self, cluster_index):
132133
colors = [
@@ -149,14 +150,19 @@ def indexing_clusters_all(
149150
threshold,
150151
):
151152
"""
152-
Args:
153-
threshold : threshold for similarity to consider two pixels as part of the same cluster
153+
Index all pixels in a cluster
154+
155+
Parameters
156+
----------
157+
threshold: float
158+
similarity score threshold to consider pixels as part
159+
of the same cluster
154160
"""
155-
161+
156162
sim_averaged = np.mean(self.similarity, axis=2)
157163

158-
# Assigning the background as 'counted'
159-
sim_averaged[~self.r_space_mask] = -1.0
164+
# Assigning the background as 'counted'
165+
sim_averaged[~self.r_space_mask] = -1.0
160166

161167
# color the pixels with the cluster index
162168
self.cluster_map = -1 * np.ones(
@@ -165,8 +171,8 @@ def indexing_clusters_all(
165171
self.cluster_map_rgb = np.zeros(
166172
(sim_averaged.shape[0], sim_averaged.shape[1], 4), dtype=np.float64
167173
)
168-
169-
self.cluster_map_rgb[..., 3] = 1.0 #start as opaque black
174+
175+
self.cluster_map_rgb[..., 3] = 1.0 # start as opaque black
170176

171177
# store arrays of cluster_indices in a list
172178
self.cluster_list = []
@@ -182,13 +188,12 @@ def indexing_clusters_all(
182188
# finding the pixel that has the highest self.similarity among the pixel that hasn't been clustered yet
183189
# this will be the 'starting pixel' of a new cluster
184190
rx0, ry0 = np.unravel_index(sim_averaged.argmax(), sim_averaged.shape)
185-
191+
186192
# Guarding to check if the seed is background
187193
if self.r_space_mask is not None and not self.r_space_mask[rx0, ry0]:
188194
sim_averaged[rx0, ry0] = -1 # mark processed so we don't pick it again
189195
continue
190196

191-
192197
cluster_indices = np.empty((0, 2))
193198
cluster_indices = (np.append(cluster_indices, [[rx0, ry0]], axis=0)).astype(
194199
np.int32
@@ -219,54 +224,69 @@ def indexing_clusters_all(
219224
x_ind = rx0 + self.dxy[ind, 0]
220225
y_ind = ry0 + self.dxy[ind, 1]
221226

222-
if x_ind > 1 and \
223-
y_ind > 1 and \
224-
x_ind < self.similarity.shape[0] - 2 and \
225-
y_ind < self.similarity.shape[1] - 2:
227+
if (
228+
x_ind > 1
229+
and y_ind > 1
230+
and x_ind < self.similarity.shape[0] - 2
231+
and y_ind < self.similarity.shape[1] - 2
232+
):
226233

227-
r_ok = True if self.r_space_mask is None else bool(self.r_space_mask[x_ind, y_ind])
234+
r_ok = (
235+
True
236+
if self.r_space_mask is None
237+
else bool(self.r_space_mask[x_ind, y_ind])
238+
)
228239

229240
# add if the neighbor is similar, but don't add if the neighbor is already in a cluster
230-
if self.similarity[rx0, ry0, ind] >= threshold \
231-
and self.cluster_map[x_ind, y_ind] == -1 and r_ok:
241+
if (
242+
self.similarity[rx0, ry0, ind] >= threshold
243+
and self.cluster_map[x_ind, y_ind] == -1
244+
and r_ok
245+
):
232246

233-
234247
cluster_indices = np.append(
235248
cluster_indices, [[x_ind, y_ind]], axis=0
236249
)
237250

238251
self.cluster_map[x_ind, y_ind] = cluster_count_ind
239-
240252

241-
# self.cluster_map[x_ind, y_ind] = cluster_count_ind+1
242253
color = self.get_color(cluster_count_ind + 1)
243-
self.cluster_map_rgb[x_ind, y_ind] = plt.cm.colors.to_rgba(
244-
color
254+
self.cluster_map_rgb[x_ind, y_ind] = (
255+
plt.cm.colors.to_rgba(color)
245256
)
246257

247258
# if no new pixel is checked for NN then break
248259
if counting_added_pixel == 0:
249260
break
250261

251-
# single pixel cluster
252-
# if cluster_indices.shape[0] == 1:
253-
# self.cluster_map_rgb[cluster_indices[0, 0], cluster_indices[0, 1]] = [
254-
# 0,
255-
# 0,
256-
# 0,
257-
# 1,
258-
# ]
259-
260262
self.cluster_list.append(cluster_indices)
261263
cluster_count_ind += 1
262264

263-
# return cluster_count_ind, self.cluster_list, self.cluster_map, self.cluster_map_rgb
264-
265265
def create_cluster_cube(
266266
self,
267267
min_cluster_size,
268268
return_cluster_datacube=False,
269269
):
270+
"""
271+
Create dataset (N, 1, qx, qy), where N is the number of clusters
272+
that contains diffraction patterns that are averaged across pixels
273+
in each cluster
274+
275+
Parameters
276+
----------
277+
min_cluster_size: int
278+
minimum size for a cluster to be included in dataset
279+
return_cluster_datacube: bool
280+
if True, returns clustered dataset and list of indices
281+
of clusters
282+
283+
Returns
284+
-------
285+
cluster_cube: np.ndarray
286+
dataset with clustered diffraction patterns
287+
filtered_cluster_list: list
288+
list of indices in real space of each pixel of each cluster
289+
"""
270290

271291
self.filtered_cluster_list = [
272292
arr for arr in self.cluster_list if arr.shape[0] >= min_cluster_size

0 commit comments

Comments
 (0)