|
| 1 | +"""Class which adapts clustering labels given upstream semantic predictions.""" |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import torch |
| 5 | +from torch_cluster import knn |
| 6 | +from scipy.spatial.distance import cdist |
| 7 | + |
| 8 | +from spine.data import TensorBatch |
| 9 | + |
| 10 | +from spine.utils.gnn.cluster import form_clusters, break_clusters |
| 11 | +from spine.utils.globals import ( |
| 12 | + COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP, |
| 13 | + MICHL_SHP, DELTA_SHP, GHOST_SHP) |
| 14 | + |
| 15 | +__all__ = ['ClusterLabelAdapter'] |
| 16 | + |
| 17 | + |
class ClusterLabelAdapter:
    """Adapts the cluster labels to account for the predicted semantics.

    Points wrongly predicted get the cluster label of the closest touching
    cluster, if there is one. Points that are predicted as ghosts get invalid
    (-1) cluster labels everywhere.

    Instances that have been broken up by the deghosting or semantic
    segmentation process get assigned distinct cluster labels for each
    effective fragment, provided they appear in the `break_classes` list.

    Notes
    -----
    This class supports both Numpy arrays and Torch tensors.

    It uses the GPU implementation from `torch_cluster.knn` to speed up the
    label adaptation computation (instead of cdist).
    """

    # Chebyshev distance below which two voxels are considered touching
    # (directly adjacent voxels, including diagonals, are exactly 1 apart)
    TOUCH_EPS = 1.1

    def __init__(self, break_eps=1.1, break_metric='chebyshev',
                 break_classes=(SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP)):
        """Initialize the adapter class.

        Parameters
        ----------
        break_eps : float, default 1.1
            Distance scale used in the break up procedure
        break_metric : str, default 'chebyshev'
            Distance metric used in the break up procedure
        break_classes : Sequence[int], default
                        (SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP)
            Classes to run DBSCAN on to break up
        """
        # Store relevant parameters. The default for `break_classes` is a
        # tuple (immutable) to avoid the shared mutable-default pitfall.
        self.break_eps = break_eps
        self.break_metric = break_metric
        self.break_classes = break_classes

        # Attributes used to fetch the correct functions; set per-call in
        # `__call__` based on the input type
        self.torch, self.dtype, self.device = None, None, None

    def __call__(self, clust_label, seg_label, seg_pred, ghost_pred=None):
        """Adapts the cluster labels for one entry or a batch of entries.

        Parameters
        ----------
        clust_label : Union[TensorBatch, np.ndarray, torch.Tensor]
            (N, N_l) Cluster label tensor
        seg_label : Union[TensorBatch, np.ndarray, torch.Tensor]
            (M, 5) Segmentation label tensor
        seg_pred : Union[TensorBatch, np.ndarray, torch.Tensor]
            (M/N_deghost) Segmentation predictions for each voxel
        ghost_pred : Union[TensorBatch, np.ndarray, torch.Tensor], optional
            (M) Ghost predictions for each voxel

        Returns
        -------
        Union[TensorBatch, np.ndarray, torch.Tensor]
            (N_deghost, N_l) Adapted cluster label tensor
        """
        # Set the data type/device based on the input. Fetch them from the
        # unwrapped tensor, not the TensorBatch wrapper.
        ref_tensor = clust_label
        if isinstance(ref_tensor, TensorBatch):
            ref_tensor = ref_tensor.tensor
        self.torch = isinstance(ref_tensor, torch.Tensor)

        self.dtype = ref_tensor.dtype
        if self.torch:
            self.device = ref_tensor.device

        # Dispatch depending on the data type
        if isinstance(clust_label, TensorBatch):
            # If it is batch data, call the main process function on each
            # entry. NOTE(review): this path allocates with `torch.empty`,
            # so it assumes the TensorBatch wraps torch tensors — confirm.
            shape = (seg_pred.shape[0], clust_label.shape[1])
            clust_label_adapted = torch.empty(
                    shape, dtype=ref_tensor.dtype, device=ref_tensor.device)
            for b in range(clust_label.batch_size):
                lower, upper = seg_pred.edges[b], seg_pred.edges[b+1]
                ghost_pred_b = ghost_pred[b] if ghost_pred is not None else None
                clust_label_adapted[lower:upper] = self._process(
                        clust_label[b], seg_label[b], seg_pred[b], ghost_pred_b)

            return TensorBatch(clust_label_adapted, seg_pred.counts)

        else:
            # Otherwise, call the main process function directly
            return self._process(clust_label, seg_label, seg_pred, ghost_pred)

    def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None):
        """Adapts the cluster labels for a single entry.

        Parameters
        ----------
        clust_label : Union[np.ndarray, torch.Tensor]
            (N, N_l) Cluster label tensor
        seg_label : Union[np.ndarray, torch.Tensor]
            (M, 5) Segmentation label tensor
        seg_pred : Union[np.ndarray, torch.Tensor]
            (M/N_deghost) Segmentation predictions for each voxel
        ghost_pred : Union[np.ndarray, torch.Tensor], optional
            (M) Ghost predictions for each voxel

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N_deghost, N_l) Adapted cluster label tensor
        """
        # If there are no points in this event, nothing to do
        coords = seg_label[:, :VALUE_COL]
        num_cols = clust_label.shape[1]
        if not len(coords):
            return self._ones((0, num_cols))

        # If there are no points after deghosting, nothing to do
        if ghost_pred is not None:
            deghost_index = self._where(ghost_pred == 0)[0]
            if not len(deghost_index):
                return self._ones((0, num_cols))

        # If there are no label points in this event, return dummy labels
        # (coordinates preserved, every label column set to -1)
        if not len(clust_label):
            if ghost_pred is None:
                shape = (len(coords), num_cols)
                dummy_labels = -self._ones(shape)
                dummy_labels[:, :VALUE_COL] = coords

            else:
                shape = (len(deghost_index), num_cols)
                dummy_labels = -self._ones(shape)
                dummy_labels[:, :VALUE_COL] = coords[deghost_index]

            return dummy_labels

        # Build a tensor of predicted segmentation that includes ghost points
        seg_label = self._to_long(seg_label[:, SHAPE_COL])
        if ghost_pred is not None and (len(ghost_pred) != len(seg_pred)):
            seg_pred_long = self._to_long(GHOST_SHP*self._ones(len(coords)))
            seg_pred_long[deghost_index] = seg_pred
            seg_pred = seg_pred_long

        # Prepare new labels
        new_label = -self._ones((len(coords), num_cols))
        new_label[:, :VALUE_COL] = coords

        # Check if the segment labels and predictions are compatible. If they are
        # compatible, store the cluster labels as is. Track points do not mix
        # with other classes, but EM classes are allowed to.
        compat_mat = self._eye(GHOST_SHP + 1)
        compat_mat[([SHOWR_SHP, SHOWR_SHP, MICHL_SHP, DELTA_SHP],
                    [MICHL_SHP, DELTA_SHP, SHOWR_SHP, SHOWR_SHP])] = True

        # NOTE(review): `new_label[true_deghost] = clust_label` assumes the
        # rows of `clust_label` align one-to-one with the non-ghost rows of
        # `seg_label`, in order — confirm against upstream producers.
        true_deghost = seg_label < GHOST_SHP
        seg_mismatch = ~compat_mat[(seg_pred, seg_label)]
        new_label[true_deghost] = clust_label
        new_label[true_deghost & seg_mismatch, VALUE_COL:] = -self._ones(1)

        # For mismatched predictions, attempt to find a touching instance of the
        # same class to assign it sensible cluster labels.
        for s in self._unique(seg_pred):
            # Skip predicted ghosts (they keep their invalid labels)
            if s == GHOST_SHP:
                continue

            # Restrict to points in this class that have incompatible segment
            # labels. Track points do not mix, EM points are allowed to.
            bad_index = self._where(
                    (seg_pred == s) & (~true_deghost | seg_mismatch))[0]
            if len(bad_index) == 0:
                continue

            # Find points in clust_label that have compatible segment labels
            seg_clust_mask = compat_mat[s][self._to_long(clust_label[:, SHAPE_COL])]
            X_true = clust_label[seg_clust_mask]
            if len(X_true) == 0:
                continue

            # Loop over the set of unlabeled predicted points. Newly labeled
            # points become seeds for the next pass, so labels propagate
            # through chains of touching voxels.
            X_pred = coords[bad_index]
            tagged_voxels_count = 1
            while tagged_voxels_count > 0 and len(X_pred) > 0:
                # Find the nearest neighbor to each predicted point
                closest_ids = self._compute_neighbor(X_pred, X_true)

                # Compute Chebyshev distance between predicted and closest true.
                distances = self._compute_distances(X_pred, X_true[closest_ids])

                # Label unlabeled voxels that touch a compatible true voxel
                select_mask = distances < self.TOUCH_EPS
                select_index = self._where(select_mask)[0]
                tagged_voxels_count = len(select_index)
                if tagged_voxels_count > 0:
                    # Use the label of the touching true voxel
                    additional_clust_label = self._cat(
                            [X_pred[select_index],
                             X_true[closest_ids[select_index], VALUE_COL:]], 1)
                    new_label[bad_index[select_index]] = additional_clust_label

                    # Update the mask to not include the new assigned points
                    leftover_index = self._where(~select_mask)[0]
                    bad_index = bad_index[leftover_index]

                    # The new true available points are the ones we just added.
                    # The new pred points are those not yet labeled
                    X_true = additional_clust_label
                    X_pred = X_pred[leftover_index]

        # Remove predicted ghost points from the labels, set the shape
        # column of the label to the segmentation predictions.
        if ghost_pred is not None:
            new_label = new_label[deghost_index]
            new_label[:, SHAPE_COL] = seg_pred[deghost_index]
        else:
            new_label[:, SHAPE_COL] = seg_pred

        # Build a list of cluster indexes to break (done in Numpy regardless
        # of the input type, since the clusters are plain index arrays)
        new_label_np = new_label
        if torch.is_tensor(new_label):
            new_label_np = new_label.detach().cpu().numpy()

        clusts = []
        labels = new_label_np[:, CLUST_COL]
        shapes = new_label_np[:, SHAPE_COL]
        for break_class in self.break_classes:
            index_s = np.where(shapes == break_class)[0]
            labels_s = labels[index_s]
            for c in np.unique(labels_s):
                # If the cluster ID is invalid, skip
                if c < 0:
                    continue

                # Append cluster
                clusts.append(index_s[labels_s == c])

        # Now if an instance was broken up, assign it different cluster IDs
        new_label[:, CLUST_COL] = break_clusters(
                new_label, clusts, self.break_eps, self.break_metric)

        return new_label

    def _where(self, x):
        """Dispatch `where` to the appropriate backend."""
        if self.torch:
            return torch.where(x)
        else:
            return np.where(x)

    def _cat(self, x, axis):
        """Dispatch concatenation to the appropriate backend."""
        if self.torch:
            return torch.cat(x, axis)
        else:
            return np.concatenate(x, axis)

    def _ones(self, x):
        """Create a ones tensor/array of shape `x` with the input's dtype."""
        if self.torch:
            return torch.ones(x, dtype=self.dtype, device=self.device)
        else:
            # Match the input dtype (the original fell back to float64 here,
            # inconsistently with the torch branch)
            return np.ones(x, dtype=self.dtype)

    def _eye(self, x):
        """Create a boolean identity matrix of size `x`."""
        if self.torch:
            return torch.eye(x, dtype=torch.bool, device=self.device)
        else:
            return np.eye(x, dtype=bool)

    def _unique(self, x):
        """Return the unique values of `x` as 64-bit integers."""
        if self.torch:
            return torch.unique(x).long()
        else:
            return np.unique(x).astype(np.int64)

    def _to_long(self, x):
        """Cast `x` to 64-bit integers."""
        if self.torch:
            return x.long()
        else:
            # Bug fix: the original used the undefined bare name `int64`,
            # which raised a NameError on the Numpy path
            return x.astype(np.int64)

    def _compute_neighbor(self, x, y):
        """Return, for each point in `x`, the index of its nearest point in `y`."""
        if self.torch:
            # `knn(ref, query, k)` returns (query_idx, ref_idx) rows; take
            # the reference indices
            return knn(y[:, COORD_COLS], x[:, COORD_COLS], 1)[1]
        else:
            return cdist(x[:, COORD_COLS], y[:, COORD_COLS]).argmin(axis=1)

    def _compute_distances(self, x, y):
        """Return the row-wise Chebyshev distance between `x` and `y`."""
        if self.torch:
            return torch.amax(torch.abs(y[:, COORD_COLS] - x[:, COORD_COLS]), dim=1)
        else:
            return np.amax(np.abs(x[:, COORD_COLS] - y[:, COORD_COLS]), axis=1)
0 commit comments