|
import numpy as np
import matplotlib.pyplot as plt
import pymatgen
from scipy.ndimage import binary_erosion, gaussian_filter

from py4DSTEM.process.utils import tqdmnd
| 8 | + |
class Cluster:
    """
    Cluster 4D-STEM data by grouping real-space probe positions whose
    diffraction patterns are similar.

    Similarity is the cosine similarity between a probe position's
    diffraction pattern and each of its 8 nearest real-space neighbors;
    clusters are then grown with a region-growing ("marching") algorithm.

    Typical workflow::

        c = Cluster(datacube)
        c.find_similarity(mask=...)
        c.indexing_clusters_all(mask=..., threshold=...)
        c.create_cluster_cube(min_cluster_size=...)
    """

    # Offsets of the 8 nearest real-space neighbors. Defined once so that
    # find_similarity and indexing_clusters_all always agree on the
    # neighbor ordering (the third similarity axis is indexed by this order).
    # NOTE: (-1,-1) is equivalent to (1,1) seen from the other pixel.
    _NEIGHBOR_OFFSETS = np.array(
        (
            (-1, -1),
            (-1, 0),
            (-1, 1),
            (0, -1),
            (1, 1),
            (1, 0),
            (1, -1),
            (0, 1),
        )
    )

    def __init__(
        self,
        datacube,
    ):
        """
        Args:
            datacube (py4DSTEM.DataCube): 4D-STEM data, indexed as
                datacube[rx, ry, qx, qy] (real-space scan position,
                diffraction-space pixel) — assumed 4D; TODO confirm.
        """
        self.datacube = datacube

    def find_similarity(
        self,
        mask=None,  # by default, use the full diffraction pattern
    ):
        """
        Compute the cosine similarity between the diffraction pattern at
        every probe position and each of its 8 nearest neighbors.

        Populates ``self.similarity`` with shape (R_Nx, R_Ny, 8). Entries
        remain -1 where the neighbor falls outside the scan, and where a
        pattern has zero total intensity (which would otherwise produce
        NaN via a 0/0 division).

        Args:
            mask (np.ndarray of bool, optional): diffraction-space mask;
                when given, only the masked pixels enter the similarity.
        """
        # Kept as an instance attribute for backward compatibility.
        self.dxy = self._NEIGHBOR_OFFSETS.copy()

        # Initialize to -1 so out-of-bounds neighbors are distinguishable.
        self.similarity = -1 * np.ones(
            (self.datacube.shape[0], self.datacube.shape[1], self.dxy.shape[0])
        )

        # Loop over probe positions
        for rx, ry in tqdmnd(
            range(self.datacube.shape[0]),
            range(self.datacube.shape[1]),
        ):
            if mask is None:
                diff_ref = self.datacube[rx, ry]
            else:
                diff_ref = self.datacube[rx, ry][mask]

            # Hoisted out of the neighbor loop: the reference norm does
            # not depend on the neighbor.
            norm_ref = np.sqrt(np.sum(diff_ref * diff_ref))

            # Loop over neighbors
            for ind in range(self.dxy.shape[0]):
                x_ind = rx + self.dxy[ind, 0]
                y_ind = ry + self.dxy[ind, 1]
                if (
                    x_ind >= 0
                    and y_ind >= 0
                    and x_ind < self.datacube.shape[0]
                    and y_ind < self.datacube.shape[1]
                ):
                    if mask is None:
                        diff = self.datacube[x_ind, y_ind]
                    else:
                        diff = self.datacube[x_ind, y_ind][mask]

                    norm = np.sqrt(np.sum(diff * diff))

                    # Cosine similarity (normalized correlation).
                    # Guard against zero-intensity patterns: previously a
                    # 0/0 produced NaN, which poisoned the averaged
                    # similarity map used by indexing_clusters_all.
                    if norm > 0 and norm_ref > 0:
                        self.similarity[rx, ry, ind] = (
                            np.sum(diff * diff_ref) / norm / norm_ref
                        )

    def get_color(self, cluster_index):
        """Map a 1-based cluster index to a matplotlib color name (cycles
        through a fixed 10-color palette)."""
        colors = [
            "slategray",
            "lightcoral",
            "gold",
            "darkorange",
            "yellowgreen",
            "lightseagreen",
            "cornflowerblue",
            "royalblue",
            "lightsteelblue",
            "darkseagreen",
        ]
        return colors[(cluster_index - 1) % len(colors)]

    def indexing_clusters_all(
        self,
        mask,
        threshold,
    ):
        """
        Grow clusters over the whole scan.

        Each cluster is seeded at the not-yet-visited pixel with the
        highest neighbor-averaged similarity, then grown by repeatedly
        absorbing neighbors whose pairwise similarity exceeds
        ``threshold`` (marching algorithm). Requires ``find_similarity``
        to have been run first.

        Args:
            mask: currently unused; kept for interface compatibility.
            threshold (float): cosine-similarity cutoff above which a
                neighbor is absorbed into the current cluster.

        Populates:
            self.cluster_map (np.ndarray): (R_Nx, R_Ny, 4) RGBA image,
                one color per cluster; single-pixel clusters are opaque
                black; uncolored pixels stay (0, 0, 0, 0).
            self.cluster_list (list of np.ndarray): one (n_pixels, 2)
                array of (rx, ry) indices per cluster.
        """
        # Kept as an instance attribute for backward compatibility.
        self.dxy = self._NEIGHBOR_OFFSETS.copy()

        # Per-pixel similarity averaged over the 8 neighbors.
        # NOTE(review): out-of-bounds neighbors contribute -1 to this
        # mean, which biases edge pixels low — confirm this is intended.
        sim_averaged = np.mean(self.similarity, axis=2)
        nx, ny = sim_averaged.shape

        # RGBA image colored per cluster; (0,0,0,0) == not yet clustered.
        self.cluster_map = np.zeros((nx, ny, 4), dtype=np.float64)

        # One (n_pixels, 2) index array per cluster.
        self.cluster_list = []

        cluster_count_ind = 0

        # sim_averaged[rx, ry] is set to -1 once that pixel's neighbors
        # have been searched; loop until every pixel has been visited.
        while np.any(sim_averaged != -1):

            # Seed the new cluster at the most self-similar unvisited pixel.
            rx0, ry0 = np.unravel_index(sim_averaged.argmax(), sim_averaged.shape)

            cluster_indices = np.array([[rx0, ry0]], dtype=np.int32)

            color = self.get_color(cluster_count_ind + 1)
            self.cluster_map[rx0, ry0] = plt.cm.colors.to_rgba(color)

            # Marching: keep sweeping the cluster's pixels, absorbing
            # similar neighbors, until a sweep visits no new pixel.
            while True:
                counting_added_pixel = 0

                for rx0, ry0 in cluster_indices:

                    if sim_averaged[rx0, ry0] == -1:
                        # This pixel's neighbors were already searched.
                        continue

                    counting_added_pixel += 1
                    sim_averaged[rx0, ry0] = -1  # mark as searched

                    for ind in range(self.dxy.shape[0]):
                        x_ind = rx0 + self.dxy[ind, 0]
                        y_ind = ry0 + self.dxy[ind, 1]

                        # Bug fix: reject out-of-range neighbors before
                        # indexing cluster_map. Previously a negative
                        # index silently wrapped around the array (and a
                        # too-large one could raise IndexError when
                        # threshold <= -1).
                        if not (0 <= x_ind < nx and 0 <= y_ind < ny):
                            continue

                        # Absorb if similar enough and not already in any
                        # cluster (uncolored pixels are (0,0,0,0)).
                        if self.similarity[
                            rx0, ry0, ind
                        ] > threshold and np.array_equal(
                            self.cluster_map[x_ind, y_ind], [0, 0, 0, 0]
                        ):
                            cluster_indices = np.append(
                                cluster_indices, [[x_ind, y_ind]], axis=0
                            )
                            color = self.get_color(cluster_count_ind + 1)
                            self.cluster_map[x_ind, y_ind] = plt.cm.colors.to_rgba(
                                color
                            )

                # No pixel was searched this sweep -> cluster complete.
                if counting_added_pixel == 0:
                    break

            # Single-pixel clusters are painted opaque black.
            if cluster_indices.shape[0] == 1:
                self.cluster_map[cluster_indices[0, 0], cluster_indices[0, 1]] = [
                    0,
                    0,
                    0,
                    1,
                ]

            self.cluster_list.append(cluster_indices)
            cluster_count_ind += 1

    def create_cluster_cube(
        self,
        min_cluster_size,
        return_cluster_datacube=False,
    ):
        """
        Average the diffraction patterns of each sufficiently large cluster.

        Requires ``indexing_clusters_all`` to have been run first.

        Args:
            min_cluster_size (int): clusters with fewer pixels than this
                are dropped.
            return_cluster_datacube (bool): if True, return
                (cluster_cube, filtered_cluster_list); otherwise results
                are only stored on self.

        Populates:
            self.filtered_cluster_list (list): clusters passing the size cut.
            self.cluster_cube (np.ndarray): shape
                (n_clusters, 1, Q_Nx, Q_Ny) — axis 1 is a placeholder so
                each cluster's mean pattern can be treated like a datacube.
        """
        self.filtered_cluster_list = [
            arr for arr in self.cluster_list if arr.shape[0] >= min_cluster_size
        ]

        self.cluster_cube = np.empty(
            [
                len(self.filtered_cluster_list),
                1,
                self.datacube.shape[2],
                self.datacube.shape[3],
            ]
        )

        for i in tqdmnd(range(len(self.filtered_cluster_list))):
            # Convert once (was converted twice per iteration before).
            inds = np.asarray(self.filtered_cluster_list[i])
            self.cluster_cube[i, 0] = self.datacube[inds[:, 0], inds[:, 1]].mean(
                axis=0
            )

        if return_cluster_datacube:
            return self.cluster_cube, self.filtered_cluster_list