
Commit 68fba2f

Merge pull request #26 from computational-cell-analytics/postprocessing_with_connected_components
Postprocessing cochlea segmentation using erosion
2 parents 4af55c2 + 6c4e504

File tree: 6 files changed, +362 −49 lines


environment.yaml

Lines changed: 2 additions & 1 deletion
@@ -13,4 +13,5 @@ dependencies:
 - s3fs
 - torch_em
 - z5py
-- zarr
+# Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet
+- zarr <3
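
Since the pin keeps zarr below v3, the Store import shim added to the Python modules below should resolve to the v2 base class. A minimal sketch, assuming only the two import paths shown in the diffs, to check which branch is active in a given environment:

import zarr

# Mirrors the try/except shim from file_utils.py and s3_utils.py:
# zarr v3 exposes zarr.abc.store.Store, while zarr v2 provides
# zarr._storage.store.BaseStore instead.
try:
    from zarr.abc.store import Store
    origin = "zarr v3 (zarr.abc.store.Store)"
except ImportError:
    from zarr._storage.store import BaseStore as Store
    origin = "zarr v2 (zarr._storage.store.BaseStore)"

print(f"zarr {zarr.__version__}: Store resolves to {origin}")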

flamingo_tools/file_utils.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,11 @@
 import zarr
 from elf.io import open_file

+try:
+    from zarr.abc.store import Store
+except ImportError:
+    from zarr._storage.store import BaseStore as Store
+

 def _parse_shape(metadata_file):
     depth, height, width = None, None, None
@@ -62,7 +67,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]:
     return x


-def read_image_data(input_path: Union[str, zarr.storage.FSStore], input_key: Optional[str]) -> np.typing.ArrayLike:
+def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> np.typing.ArrayLike:
     """Read flamingo image data, stored in various formats.

     Args:
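
A hypothetical usage sketch of the retyped function; the container path and dataset key below are placeholders, not paths from this repository:

from flamingo_tools.file_utils import read_image_data

# Placeholder path and key: read_image_data accepts a file path or a zarr
# store, plus an optional key for container formats such as n5/zarr.
data = read_image_data("/path/to/volume.n5", input_key="setup0/timepoint0/s0")
print(data.shape)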

flamingo_tools/s3_utils.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,11 @@
 import s3fs
 import zarr

+try:
+    from zarr.abc.store import Store
+except ImportError:
+    from zarr._storage.store import BaseStore as Store
+

 # Dedicated bucket for cochlea lightsheet project
 MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
@@ -93,7 +98,7 @@ def get_s3_path(
     bucket_name: Optional[str] = None,
     service_endpoint: Optional[str] = None,
     credential_file: Optional[str] = None,
-) -> Tuple[zarr.storage.FSStore, s3fs.core.S3FileSystem]:
+) -> Tuple[Store, s3fs.core.S3FileSystem]:
     """Get S3 path for a file or folder and file system based on S3 parameters and credentials.

     Args:
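
A hedged usage sketch; any arguments of get_s3_path that precede bucket_name are not visible in this hunk, so the leading path argument and all endpoint, bucket, and credential values below are assumptions:

import zarr

from flamingo_tools.s3_utils import get_s3_path

# All values are placeholders; the leading positional argument is assumed to
# be the path of the file or folder inside the bucket.
store, fs = get_s3_path(
    "some_dataset.ome.zarr",
    bucket_name="example-bucket",
    service_endpoint="https://s3.example.com",
    credential_file="./s3_credentials",
)
# fs is the s3fs filesystem handle; with zarr v2 (as pinned above) the
# returned store can be opened directly.
data = zarr.open(store, mode="r")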

flamingo_tools/segmentation/postprocessing.py

Lines changed: 262 additions & 2 deletions
@@ -1,15 +1,17 @@
+import math
 import multiprocessing as mp
 from concurrent import futures
-from typing import Callable, Tuple, Optional
+from typing import Callable, List, Optional, Tuple

 import elf.parallel as parallel
 import numpy as np
 import nifty.tools as nt
+import networkx as nx
 import pandas as pd

 from elf.io import open_file
-from scipy.spatial import distance
 from scipy.sparse import csr_matrix
+from scipy.spatial import distance
 from scipy.spatial import cKDTree, ConvexHull
 from skimage import measure
 from sklearn.neighbors import NearestNeighbors
@@ -205,3 +207,261 @@ def filter_chunk(block_id):
     )

     return n_ids, n_ids_filtered
+
+
+def erode_subset(
+    table: pd.DataFrame,
+    iterations: int = 1,
+    min_cells: Optional[int] = None,
+    threshold: int = 35,
+    keyword: str = "distance_nn100",
+) -> pd.DataFrame:
+    """Erode coordinates of a dataframe according to a keyword and a threshold.
+
+    Pass in a copy of the dataframe if the original should not be modified.
+
+    Args:
+        table: Dataframe of segmentation table.
+        iterations: Number of steps for the erosion process.
+        min_cells: Minimal number of rows. The erosion is stopped after falling below this limit.
+        threshold: Upper threshold for removing elements according to the given keyword.
+        keyword: Keyword of the dataframe column used for erosion.
+
+    Returns:
+        The dataframe containing the elements left after the erosion.
+    """
+    print("initial length", len(table))
+    n_neighbors = 100
+    for i in range(iterations):
+        table = table[table[keyword] < threshold]
+
+        distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)
+
+        if min_cells is not None and len(distance_avg) < min_cells:
+            print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion")
+            break
+
+        table.loc[:, "distance_nn" + str(n_neighbors)] = list(distance_avg)
+
+        print(f"{i}-th iteration, length of subset {len(table)}")
+
+    return table
+
+
+def downscaled_centroids(
+    table: pd.DataFrame,
+    scale_factor: int,
+    ref_dimensions: Optional[Tuple[float, float, float]] = None,
+    downsample_mode: str = "accumulated",
+) -> np.typing.NDArray:
+    """Downscale the centroids of a dataframe.
+
+    Args:
+        table: Dataframe of segmentation table.
+        scale_factor: Factor for downscaling coordinates.
+        ref_dimensions: Reference dimensions for downscaling. Taken from the centroids if not supplied.
+        downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'.
+
+    Returns:
+        The downscaled array.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]
+
+    if ref_dimensions is None:
+        bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+    else:
+        bounding_dimensions = ref_dimensions
+    bounding_dimensions_scaled = tuple(round(b // scale_factor + 1) for b in bounding_dimensions)
+    new_array = np.zeros(bounding_dimensions_scaled)
+
+    if downsample_mode == "accumulated":
+        for c in centroids_scaled:
+            new_array[int(c[0]), int(c[1]), int(c[2])] += 1
+
+    elif downsample_mode == "capped":
+        for c in centroids_scaled:
+            new_array[int(c[0]), int(c[1]), int(c[2])] = 1
+
+    elif downsample_mode == "components":
+        if "component_labels" not in table.columns:
+            raise KeyError("Dataframe must contain the column 'component_labels' for downsampling with mode 'components'.")
+        component_labels = list(table["component_labels"])
+        for comp, centr in zip(component_labels, centroids_scaled):
+            if comp != 0:
+                new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp
+
+    else:
+        raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.")
+
+    return np.round(new_array).astype(int)
+
+
+def components_sgn(
+    table: pd.DataFrame,
+    keyword: str = "distance_nn100",
+    threshold_erode: Optional[float] = None,
+    postprocess_graph: bool = False,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> List[List[int]]:
+    """Erode the SGN segmentation and extract its connected components.
+
+    Args:
+        table: Dataframe of segmentation table.
+        keyword: Keyword of the dataframe column used for erosion.
+        threshold_erode: Threshold of column value after erosion step with spatial statistics.
+        postprocess_graph: Post-process graph connected components by searching for nearby points.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Minimal distance in micrometers between points to create edges for connected components.
+        iterations_erode: Number of iterations for erosion, normally determined automatically.
+
+    Returns:
+        Subgraph components as lists of label_ids of the dataframe.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    labels = [int(i) for i in list(table["label_id"])]
+
+    distance_nn = list(table[keyword])
+    distance_nn.sort()
+
+    if len(table) < 20000:
+        iterations = iterations_erode if iterations_erode is not None else 0
+        min_cells = None
+        # Use the 80th percentile of the nearest-neighbor distances as the default threshold.
+        percentile_dist = int(distance_nn[int(len(table) * 0.8)])
+        threshold = threshold_erode if threshold_erode is not None else percentile_dist
+    else:
+        iterations = iterations_erode if iterations_erode is not None else 15
+        min_cells = 20000
+        threshold = threshold_erode if threshold_erode is not None else 40
+
+    print(f"Using threshold of {threshold} micrometers for eroding segmentation with keyword {keyword}.")
+
+    new_subset = erode_subset(table.copy(), iterations=iterations,
+                              threshold=threshold, min_cells=min_cells, keyword=keyword)
+
+    # Create a graph from the coordinates of the eroded subset.
+    centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
+    labels_subset = [int(i) for i in list(new_subset["label_id"])]
+    coords = {index: element for index, element in zip(labels_subset, centroids_subset)}
+
+    graph = nx.Graph()
+    for num, pos in coords.items():
+        graph.add_node(num, pos=pos)
+
+    # Create edges between points whose distance is at most min_edge_distance.
+    for i in coords:
+        for j in coords:
+            if i < j:
+                dist = math.dist(coords[i], coords[j])
+                if dist <= min_edge_distance:
+                    graph.add_edge(i, j, weight=dist)
+
+    components = list(nx.connected_components(graph))
+
+    # Remove connected components with fewer nodes than min_component_length.
+    for component in components:
+        if len(component) < min_component_length:
+            for c in component:
+                graph.remove_node(c)
+
+    components = [list(s) for s in nx.connected_components(graph)]
+
+    # Add original coordinates that are closer to an eroded component than the threshold.
+    if postprocess_graph:
+        threshold = 15
+        for label_id, centr in zip(labels, centroids):
+            if label_id not in labels_subset:
+                add_coord = []
+                for comp_index, component in enumerate(components):
+                    for comp_label in component:
+                        # Note: this indexing assumes contiguous label_ids starting at 1.
+                        dist = math.dist(centr, centroids[comp_label - 1])
+                        if dist <= threshold:
+                            add_coord.append([comp_index, label_id])
+                            break
+                if len(add_coord) != 0:
+                    components[add_coord[0][0]].append(add_coord[0][1])
+
+    return components
+
+
+def label_components(
+    table: pd.DataFrame,
+    min_size: int = 1000,
+    threshold_erode: Optional[float] = None,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> List[int]:
+    """Label components using graph connected components.
+
+    Args:
+        table: Dataframe of segmentation table.
+        min_size: Minimal number of pixels for filtering small instances.
+        threshold_erode: Threshold of column value after erosion step with spatial statistics.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Minimal distance in micrometers between points to create edges for connected components.
+        iterations_erode: Number of iterations for erosion, normally determined automatically.
+
+    Returns:
+        List with a component label for each point in the dataframe: 0 for background, then components in descending order of size.
+    """
+    # First, apply the size filter.
+    entries_filtered = table[table.n_pixels < min_size]
+    table = table[table.n_pixels >= min_size]
+
+    components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
+                                min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
+
+    # Add the size-filtered objects back, so that the table regains its initial length.
+    table = pd.concat([table, entries_filtered], ignore_index=True)
+    table = table.sort_values("label_id")
+
+    length_components = [len(c) for c in components]
+    length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
+
+    component_labels = [0 for _ in range(len(table))]
+    # Be aware that the 'label_id' of the dataframe starts at 1.
+    for lab, comp in enumerate(components):
+        for comp_index in comp:
+            component_labels[comp_index - 1] = lab + 1
+
+    return component_labels
+
+
+def postprocess_sgn_seg(
+    table: pd.DataFrame,
+    min_size: int = 1000,
+    threshold_erode: Optional[float] = None,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> pd.DataFrame:
+    """Postprocess the SGN segmentation of the cochlea.
+
+    Args:
+        table: Dataframe of segmentation table.
+        min_size: Minimal number of pixels for filtering small instances.
+        threshold_erode: Threshold of column value after erosion step with spatial statistics.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Minimal distance in micrometers between points to create edges for connected components.
+        iterations_erode: Number of iterations for erosion, normally determined automatically.
+
+    Returns:
+        The dataframe with component labels.
+    """
+    comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
+                                   min_component_length=min_component_length,
+                                   min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
+
+    table.loc[:, "component_labels"] = comp_labels
+
+    return table
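
A hedged end-to-end sketch of the new entry point; the TSV path is a placeholder, and it assumes the table already contains 'label_id', 'n_pixels', the 'anchor_x/y/z' columns, and the 'distance_nn100' column, presumably produced by a spatial-statistics script such as expand_seg_table.py below:

import pandas as pd

from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg

# Placeholder path to a segmentation table with the columns listed above.
table = pd.read_csv("SGN_seg_table.tsv", sep="\t")
table = postprocess_sgn_seg(table, min_size=1000, min_component_length=50)

# Component label 0 marks background / filtered objects; the remaining
# components are numbered in descending order of size.
print(table["component_labels"].value_counts().head())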

scripts/prediction/expand_seg_table.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def main(
     neighbor_counts = [n[0] for n in neighbor_counts]
     tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts

-    tsv_table.to_csv(out_path, sep="\t")
+    tsv_table.to_csv(out_path, sep="\t", index=False)


 if __name__ == "__main__":
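
For context, a minimal sketch of what the added index=False changes when writing the TSV (the dataframe content is illustrative):

import pandas as pd

df = pd.DataFrame({"label_id": [1, 2], "n_pixels": [1200, 800]})

# Without index=False, pandas writes the unnamed row index as an extra leading
# column; with index=False only the data columns end up in the file.
df.to_csv("table.tsv", sep="\t", index=False)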
