|
| 1 | +import math |
1 | 2 | import multiprocessing as mp |
2 | 3 | from concurrent import futures |
3 | | -from typing import Callable, Tuple, Optional |
| 4 | +from typing import Callable, List, Optional, Tuple |
4 | 5 |
|
5 | 6 | import elf.parallel as parallel |
6 | 7 | import numpy as np |
7 | 8 | import nifty.tools as nt |
| 9 | +import networkx as nx |
8 | 10 | import pandas as pd |
9 | 11 |
|
10 | 12 | from elf.io import open_file |
11 | | -from scipy.spatial import distance |
12 | 13 | from scipy.sparse import csr_matrix |
| 14 | +from scipy.spatial import distance |
13 | 15 | from scipy.spatial import cKDTree, ConvexHull |
14 | 16 | from skimage import measure |
15 | 17 | from sklearn.neighbors import NearestNeighbors |
@@ -205,3 +207,261 @@ def filter_chunk(block_id): |
205 | 207 | ) |
206 | 208 |
|
207 | 209 | return n_ids, n_ids_filtered |
| 210 | + |
| 211 | + |
| 212 | +def erode_subset( |
| 213 | + table: pd.DataFrame, |
| 214 | + iterations: int = 1, |
| 215 | + min_cells: Optional[int] = None, |
| 216 | + threshold: int = 35, |
| 217 | + keyword: str = "distance_nn100", |
| 218 | +) -> pd.DataFrame: |
| 219 | + """Erode coordinates of dataframe according to a keyword and a threshold. |
| 220 | + Use a copy of the dataframe as an input, if it should not be edited. |
| 221 | +
|
| 222 | + Args: |
| 223 | + table: Dataframe of segmentation table. |
| 224 | + iterations: Number of steps for erosion process. |
| 225 | + min_cells: Minimal number of rows. The erosion is stopped after falling below this limit. |
| 226 | + threshold: Upper threshold for removing elements according to the given keyword. |
| 227 | + keyword: Keyword of dataframe for erosion. |
| 228 | +
|
| 229 | + Returns: |
| 230 | + The dataframe containing elements left after the erosion. |
| 231 | + """ |
| 232 | + print("initial length", len(table)) |
| 233 | + n_neighbors = 100 |
| 234 | + for i in range(iterations): |
| 235 | + table = table[table[keyword] < threshold] |
| 236 | + |
| 237 | + distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors) |
| 238 | + |
| 239 | + if min_cells is not None and len(distance_avg) < min_cells: |
| 240 | + print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion") |
| 241 | + break |
| 242 | + |
| 243 | + table.loc[:, 'distance_nn'+str(n_neighbors)] = list(distance_avg) |
| 244 | + |
| 245 | + print(f"{i}-th iteration, length of subset {len(table)}") |
| 246 | + |
| 247 | + return table |
| 248 | + |
| 249 | + |
| 250 | +def downscaled_centroids( |
| 251 | + table: pd.DataFrame, |
| 252 | + scale_factor: int, |
| 253 | + ref_dimensions: Optional[Tuple[float, float, float]] = None, |
| 254 | + downsample_mode: str = "accumulated", |
| 255 | +) -> np.typing.NDArray: |
| 256 | + """Downscale centroids in dataframe. |
| 257 | +
|
| 258 | + Args: |
| 259 | + table: Dataframe of segmentation table. |
| 260 | + scale_factor: Factor for downscaling coordinates. |
| 261 | + ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied. |
| 262 | + downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'. |
| 263 | +
|
| 264 | + Returns: |
| 265 | + The downscaled array |
| 266 | + """ |
| 267 | + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) |
| 268 | + centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids] |
| 269 | + |
| 270 | + if ref_dimensions is None: |
| 271 | + bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"])) |
| 272 | + bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in bounding_dimensions]) |
| 273 | + new_array = np.zeros(bounding_dimensions_scaled) |
| 274 | + |
| 275 | + else: |
| 276 | + bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions]) |
| 277 | + new_array = np.zeros(bounding_dimensions_scaled) |
| 278 | + |
| 279 | + if downsample_mode == "accumulated": |
| 280 | + for c in centroids_scaled: |
| 281 | + new_array[int(c[0]), int(c[1]), int(c[2])] += 1 |
| 282 | + |
| 283 | + elif downsample_mode == "capped": |
| 284 | + for c in centroids_scaled: |
| 285 | + new_array[int(c[0]), int(c[1]), int(c[2])] = 1 |
| 286 | + |
| 287 | + elif downsample_mode == "components": |
| 288 | + if "component_labels" not in table.columns: |
| 289 | + raise KeyError("Dataframe must continue key 'component_labels' for downsampling with mode 'components'.") |
| 290 | + component_labels = list(table["component_labels"]) |
| 291 | + for comp, centr in zip(component_labels, centroids_scaled): |
| 292 | + if comp != 0: |
| 293 | + new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp |
| 294 | + |
| 295 | + else: |
| 296 | + raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.") |
| 297 | + |
| 298 | + new_array = np.round(new_array).astype(int) |
| 299 | + |
| 300 | + return new_array |
| 301 | + |
| 302 | + |
| 303 | +def components_sgn( |
| 304 | + table: pd.DataFrame, |
| 305 | + keyword: str = "distance_nn100", |
| 306 | + threshold_erode: Optional[float] = None, |
| 307 | + postprocess_graph: bool = False, |
| 308 | + min_component_length: int = 50, |
| 309 | + min_edge_distance: float = 30, |
| 310 | + iterations_erode: Optional[int] = None, |
| 311 | +) -> List[List[int]]: |
| 312 | + """Eroding the SGN segmentation. |
| 313 | +
|
| 314 | + Args: |
| 315 | + table: Dataframe of segmentation table. |
| 316 | + keyword: Keyword of the dataframe column for erosion. |
| 317 | + threshold_erode: Threshold of column value after erosion step with spatial statistics. |
| 318 | + postprocess_graph: Post-process graph connected components by searching for near points. |
| 319 | + min_component_length: Minimal length for filtering out connected components. |
| 320 | + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. |
| 321 | + iterations_erode: Number of iterations for erosion, normally determined automatically. |
| 322 | +
|
| 323 | + Returns: |
| 324 | + Subgraph components as lists of label_ids of dataframe. |
| 325 | + """ |
| 326 | + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) |
| 327 | + labels = [int(i) for i in list(table["label_id"])] |
| 328 | + |
| 329 | + distance_nn = list(table[keyword]) |
| 330 | + distance_nn.sort() |
| 331 | + |
| 332 | + if len(table) < 20000: |
| 333 | + iterations = iterations_erode if iterations_erode is not None else 0 |
| 334 | + min_cells = None |
| 335 | + average_dist = int(distance_nn[int(len(table) * 0.8)]) |
| 336 | + threshold = threshold_erode if threshold_erode is not None else average_dist |
| 337 | + else: |
| 338 | + iterations = iterations_erode if iterations_erode is not None else 15 |
| 339 | + min_cells = 20000 |
| 340 | + threshold = threshold_erode if threshold_erode is not None else 40 |
| 341 | + |
| 342 | + print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") |
| 343 | + |
| 344 | + new_subset = erode_subset(table.copy(), iterations=iterations, |
| 345 | + threshold=threshold, min_cells=min_cells, keyword=keyword) |
| 346 | + |
| 347 | + # create graph from coordinates of eroded subset |
| 348 | + centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) |
| 349 | + labels_subset = [int(i) for i in list(new_subset["label_id"])] |
| 350 | + coords = {} |
| 351 | + for index, element in zip(labels_subset, centroids_subset): |
| 352 | + coords[index] = element |
| 353 | + |
| 354 | + graph = nx.Graph() |
| 355 | + for num, pos in coords.items(): |
| 356 | + graph.add_node(num, pos=pos) |
| 357 | + |
| 358 | + # create edges between points whose distance is less than threshold min_edge_distance |
| 359 | + for i in coords: |
| 360 | + for j in coords: |
| 361 | + if i < j: |
| 362 | + dist = math.dist(coords[i], coords[j]) |
| 363 | + if dist <= min_edge_distance: |
| 364 | + graph.add_edge(i, j, weight=dist) |
| 365 | + |
| 366 | + components = list(nx.connected_components(graph)) |
| 367 | + |
| 368 | + # remove connected components with less nodes than threshold min_component_length |
| 369 | + for component in components: |
| 370 | + if len(component) < min_component_length: |
| 371 | + for c in component: |
| 372 | + graph.remove_node(c) |
| 373 | + |
| 374 | + components = [list(s) for s in nx.connected_components(graph)] |
| 375 | + |
| 376 | + # add original coordinates closer to eroded component than threshold |
| 377 | + if postprocess_graph: |
| 378 | + threshold = 15 |
| 379 | + for label_id, centr in zip(labels, centroids): |
| 380 | + if label_id not in labels_subset: |
| 381 | + add_coord = [] |
| 382 | + for comp_index, component in enumerate(components): |
| 383 | + for comp_label in component: |
| 384 | + dist = math.dist(centr, centroids[comp_label - 1]) |
| 385 | + if dist <= threshold: |
| 386 | + add_coord.append([comp_index, label_id]) |
| 387 | + break |
| 388 | + if len(add_coord) != 0: |
| 389 | + components[add_coord[0][0]].append(add_coord[0][1]) |
| 390 | + |
| 391 | + return components |
| 392 | + |
| 393 | + |
| 394 | +def label_components( |
| 395 | + table: pd.DataFrame, |
| 396 | + min_size: int = 1000, |
| 397 | + threshold_erode: Optional[float] = None, |
| 398 | + min_component_length: int = 50, |
| 399 | + min_edge_distance: float = 30, |
| 400 | + iterations_erode: Optional[int] = None, |
| 401 | +) -> List[int]: |
| 402 | + """Label components using graph connected components. |
| 403 | +
|
| 404 | + Args: |
| 405 | + table: Dataframe of segmentation table. |
| 406 | + min_size: Minimal number of pixels for filtering small instances. |
| 407 | + threshold_erode: Threshold of column value after erosion step with spatial statistics. |
| 408 | + min_component_length: Minimal length for filtering out connected components. |
| 409 | + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. |
| 410 | + iterations_erode: Number of iterations for erosion, normally determined automatically. |
| 411 | +
|
| 412 | + Returns: |
| 413 | + List of component label for each point in dataframe. 0 - background, then in descending order of size |
| 414 | + """ |
| 415 | + |
| 416 | + # First, apply the size filter. |
| 417 | + entries_filtered = table[table.n_pixels < min_size] |
| 418 | + table = table[table.n_pixels >= min_size] |
| 419 | + |
| 420 | + components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length, |
| 421 | + min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) |
| 422 | + |
| 423 | + # add size-filtered objects to have same initial length |
| 424 | + table = pd.concat([table, entries_filtered], ignore_index=True) |
| 425 | + table.sort_values("label_id") |
| 426 | + |
| 427 | + length_components = [len(c) for c in components] |
| 428 | + length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) |
| 429 | + |
| 430 | + component_labels = [0 for _ in range(len(table))] |
| 431 | + # be aware of 'label_id' of dataframe starting at 1 |
| 432 | + for lab, comp in enumerate(components): |
| 433 | + for comp_index in comp: |
| 434 | + component_labels[comp_index - 1] = lab + 1 |
| 435 | + |
| 436 | + return component_labels |
| 437 | + |
| 438 | + |
| 439 | +def postprocess_sgn_seg( |
| 440 | + table: pd.DataFrame, |
| 441 | + min_size: int = 1000, |
| 442 | + threshold_erode: Optional[float] = None, |
| 443 | + min_component_length: int = 50, |
| 444 | + min_edge_distance: float = 30, |
| 445 | + iterations_erode: Optional[int] = None, |
| 446 | +) -> pd.DataFrame: |
| 447 | + """Postprocessing SGN segmentation of cochlea. |
| 448 | +
|
| 449 | + Args: |
| 450 | + table: Dataframe of segmentation table. |
| 451 | + min_size: Minimal number of pixels for filtering small instances. |
| 452 | + threshold_erode: Threshold of column value after erosion step with spatial statistics. |
| 453 | + min_component_length: Minimal length for filtering out connected components. |
| 454 | + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. |
| 455 | + iterations_erode: Number of iterations for erosion, normally determined automatically. |
| 456 | +
|
| 457 | + Returns: |
| 458 | + Dataframe with component labels. |
| 459 | + """ |
| 460 | + |
| 461 | + comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode, |
| 462 | + min_component_length=min_component_length, |
| 463 | + min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) |
| 464 | + |
| 465 | + table.loc[:, "component_labels"] = comp_labels |
| 466 | + |
| 467 | + return table |
0 commit comments