
Commit 4cf9e41

Merge branch 'master' into measurements
2 parents 33aae70 + 68fba2f commit 4cf9e41

File tree: 6 files changed, +362 / -49 lines

environment.yaml

Lines changed: 2 additions & 1 deletion
@@ -14,4 +14,5 @@ dependencies:
   - torch_em
   - trimesh
   - z5py
-  - zarr
+  # Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet
+  - zarr <3
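
The pin keeps the environment on the zarr v2 API until MoBIE compatibility with zarr v3 is confirmed. As a rough illustration (not part of the commit), a script can guard against an accidentally resolved v3 install with a simple major-version check:

# Minimal sketch, assuming only that zarr exposes __version__; the error text is illustrative.
import zarr

if int(zarr.__version__.split(".")[0]) >= 3:
    raise RuntimeError("zarr v3 detected; this environment currently expects zarr < 3.")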

flamingo_tools/file_utils.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,11 @@
 import zarr
 from elf.io import open_file
 
+try:
+    from zarr.abc.store import Store
+except ImportError:
+    from zarr._storage.store import BaseStore as Store
+
 
 def _parse_shape(metadata_file):
     depth, height, width = None, None, None
@@ -62,7 +67,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]:
     return x
 
 
-def read_image_data(input_path: Union[str, zarr.storage.FSStore], input_key: Optional[str]) -> np.typing.ArrayLike:
+def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> np.typing.ArrayLike:
     """Read flamingo image data, stored in various formats.
 
     Args:
flamingo_tools/s3_utils.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,11 @@
 import s3fs
 import zarr
 
+try:
+    from zarr.abc.store import Store
+except ImportError:
+    from zarr._storage.store import BaseStore as Store
+
 
 # Dedicated bucket for cochlea lightsheet project
 MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
@@ -93,7 +98,7 @@ def get_s3_path(
     bucket_name: Optional[str] = None,
     service_endpoint: Optional[str] = None,
     credential_file: Optional[str] = None,
-) -> Tuple[zarr.storage.FSStore, s3fs.core.S3FileSystem]:
+) -> Tuple[Store, s3fs.core.S3FileSystem]:
     """Get S3 path for a file or folder and file system based on S3 parameters and credentials.
 
     Args:
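
The same compatibility shim lets get_s3_path advertise a store type that exists on both zarr v2 and v3. A hedged usage sketch follows; the object path, the bucket name, and the assumption that the first positional argument is the path inside the bucket are all illustrative, not taken from the repository:

# Minimal sketch with assumed arguments.
from flamingo_tools.s3_utils import get_s3_path
from flamingo_tools.file_utils import read_image_data

s3_store, fs = get_s3_path(
    "example_cochlea/images/ome-zarr/stain.ome.zarr",  # hypothetical object path
    bucket_name="cochlea-lightsheet",                   # hypothetical bucket name
)
data = read_image_data(s3_store, input_key="s0")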

flamingo_tools/segmentation/postprocessing.py

Lines changed: 262 additions & 2 deletions
@@ -1,15 +1,17 @@
+import math
 import multiprocessing as mp
 from concurrent import futures
-from typing import Callable, Tuple, Optional
+from typing import Callable, List, Optional, Tuple
 
 import elf.parallel as parallel
 import numpy as np
 import nifty.tools as nt
+import networkx as nx
 import pandas as pd
 
 from elf.io import open_file
-from scipy.spatial import distance
 from scipy.sparse import csr_matrix
+from scipy.spatial import distance
 from scipy.spatial import cKDTree, ConvexHull
 from skimage import measure
 from sklearn.neighbors import NearestNeighbors
@@ -213,3 +215,261 @@ def filter_chunk(block_id):
     )
 
     return n_ids, n_ids_filtered
+
+
+def erode_subset(
+    table: pd.DataFrame,
+    iterations: int = 1,
+    min_cells: Optional[int] = None,
+    threshold: int = 35,
+    keyword: str = "distance_nn100",
+) -> pd.DataFrame:
+    """Erode coordinates of a dataframe according to a keyword and a threshold.
+
+    Pass a copy of the dataframe as input if the original should not be edited.
+
+    Args:
+        table: Dataframe of segmentation table.
+        iterations: Number of steps for the erosion process.
+        min_cells: Minimal number of rows. The erosion is stopped after falling below this limit.
+        threshold: Upper threshold for removing elements according to the given keyword.
+        keyword: Keyword of the dataframe column used for the erosion.
+
+    Returns:
+        The dataframe containing the elements left after the erosion.
+    """
+    print("initial length", len(table))
+    n_neighbors = 100
+    for i in range(iterations):
+        table = table[table[keyword] < threshold]
+
+        distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)
+
+        if min_cells is not None and len(distance_avg) < min_cells:
+            print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion")
+            break
+
+        table.loc[:, 'distance_nn' + str(n_neighbors)] = list(distance_avg)
+
+        print(f"{i}-th iteration, length of subset {len(table)}")
+
+    return table
+
+
+def downscaled_centroids(
+    table: pd.DataFrame,
+    scale_factor: int,
+    ref_dimensions: Optional[Tuple[float, float, float]] = None,
+    downsample_mode: str = "accumulated",
+) -> np.typing.NDArray:
+    """Downscale centroids in a dataframe.
+
+    Args:
+        table: Dataframe of segmentation table.
+        scale_factor: Factor for downscaling the coordinates.
+        ref_dimensions: Reference dimensions for downscaling. Taken from the centroids if not supplied.
+        downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'.
+
+    Returns:
+        The downscaled array.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]
+
+    if ref_dimensions is None:
+        bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+        bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in bounding_dimensions])
+        new_array = np.zeros(bounding_dimensions_scaled)
+
+    else:
+        bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions])
+        new_array = np.zeros(bounding_dimensions_scaled)
+
+    if downsample_mode == "accumulated":
+        for c in centroids_scaled:
+            new_array[int(c[0]), int(c[1]), int(c[2])] += 1
+
+    elif downsample_mode == "capped":
+        for c in centroids_scaled:
+            new_array[int(c[0]), int(c[1]), int(c[2])] = 1
+
+    elif downsample_mode == "components":
+        if "component_labels" not in table.columns:
+            raise KeyError("Dataframe must contain key 'component_labels' for downsampling with mode 'components'.")
+        component_labels = list(table["component_labels"])
+        for comp, centr in zip(component_labels, centroids_scaled):
+            if comp != 0:
+                new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp
+
+    else:
+        raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.")
+
+    new_array = np.round(new_array).astype(int)
+
+    return new_array
+
+
+def components_sgn(
+    table: pd.DataFrame,
+    keyword: str = "distance_nn100",
+    threshold_erode: Optional[float] = None,
+    postprocess_graph: bool = False,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> List[List[int]]:
+    """Erode the SGN segmentation and group the remaining centroids into graph connected components.
+
+    Args:
+        table: Dataframe of segmentation table.
+        keyword: Keyword of the dataframe column used for the erosion.
+        threshold_erode: Threshold of the column value after the erosion step with spatial statistics.
+        postprocess_graph: Post-process the graph connected components by searching for nearby points.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Distance threshold in micrometer below which an edge is created between two points.
+        iterations_erode: Number of iterations for the erosion, normally determined automatically.
+
+    Returns:
+        Subgraph components as lists of label_ids of the dataframe.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    labels = [int(i) for i in list(table["label_id"])]
+
+    distance_nn = list(table[keyword])
+    distance_nn.sort()
+
+    if len(table) < 20000:
+        iterations = iterations_erode if iterations_erode is not None else 0
+        min_cells = None
+        average_dist = int(distance_nn[int(len(table) * 0.8)])
+        threshold = threshold_erode if threshold_erode is not None else average_dist
+    else:
+        iterations = iterations_erode if iterations_erode is not None else 15
+        min_cells = 20000
+        threshold = threshold_erode if threshold_erode is not None else 40
+
+    print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")
+
+    new_subset = erode_subset(table.copy(), iterations=iterations,
+                              threshold=threshold, min_cells=min_cells, keyword=keyword)
+
+    # create graph from coordinates of the eroded subset
+    centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
+    labels_subset = [int(i) for i in list(new_subset["label_id"])]
+    coords = {}
+    for index, element in zip(labels_subset, centroids_subset):
+        coords[index] = element
+
+    graph = nx.Graph()
+    for num, pos in coords.items():
+        graph.add_node(num, pos=pos)
+
+    # create edges between points whose distance is less than the threshold min_edge_distance
+    for i in coords:
+        for j in coords:
+            if i < j:
+                dist = math.dist(coords[i], coords[j])
+                if dist <= min_edge_distance:
+                    graph.add_edge(i, j, weight=dist)
+
+    components = list(nx.connected_components(graph))
+
+    # remove connected components with fewer nodes than the threshold min_component_length
+    for component in components:
+        if len(component) < min_component_length:
+            for c in component:
+                graph.remove_node(c)
+
+    components = [list(s) for s in nx.connected_components(graph)]
+
+    # add original coordinates that are closer to an eroded component than the threshold
+    if postprocess_graph:
+        threshold = 15
+        for label_id, centr in zip(labels, centroids):
+            if label_id not in labels_subset:
+                add_coord = []
+                for comp_index, component in enumerate(components):
+                    for comp_label in component:
+                        dist = math.dist(centr, centroids[comp_label - 1])
+                        if dist <= threshold:
+                            add_coord.append([comp_index, label_id])
+                            break
+                if len(add_coord) != 0:
+                    components[add_coord[0][0]].append(add_coord[0][1])
+
+    return components
+
+
+def label_components(
+    table: pd.DataFrame,
+    min_size: int = 1000,
+    threshold_erode: Optional[float] = None,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> List[int]:
+    """Label components using graph connected components.
+
+    Args:
+        table: Dataframe of segmentation table.
+        min_size: Minimal number of pixels for filtering out small instances.
+        threshold_erode: Threshold of the column value after the erosion step with spatial statistics.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Distance threshold in micrometer below which an edge is created between two points.
+        iterations_erode: Number of iterations for the erosion, normally determined automatically.
+
+    Returns:
+        List of component labels for each point in the dataframe: 0 for background, then labels in descending order of component size.
+    """
+
+    # First, apply the size filter.
+    entries_filtered = table[table.n_pixels < min_size]
+    table = table[table.n_pixels >= min_size]
+
+    components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
+                                min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
+
+    # add the size-filtered objects again to restore the initial length
+    table = pd.concat([table, entries_filtered], ignore_index=True)
+    table = table.sort_values("label_id")
+
+    length_components = [len(c) for c in components]
+    length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
+
+    component_labels = [0 for _ in range(len(table))]
+    # be aware that the 'label_id' of the dataframe starts at 1
+    for lab, comp in enumerate(components):
+        for comp_index in comp:
+            component_labels[comp_index - 1] = lab + 1
+
+    return component_labels
+
+
+def postprocess_sgn_seg(
+    table: pd.DataFrame,
+    min_size: int = 1000,
+    threshold_erode: Optional[float] = None,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> pd.DataFrame:
+    """Postprocess the SGN segmentation of the cochlea.
+
+    Args:
+        table: Dataframe of segmentation table.
+        min_size: Minimal number of pixels for filtering out small instances.
+        threshold_erode: Threshold of the column value after the erosion step with spatial statistics.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Distance threshold in micrometer below which an edge is created between two points.
+        iterations_erode: Number of iterations for the erosion, normally determined automatically.
+
+    Returns:
+        Dataframe with component labels.
+    """
+
+    comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
+                                   min_component_length=min_component_length,
+                                   min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
+
+    table.loc[:, "component_labels"] = comp_labels
+
+    return table
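
Taken together, the added functions form a post-processing pipeline: filter small instances by n_pixels, erode the table based on a nearest-neighbor distance column, link the remaining centroids into a graph, keep sufficiently large connected components, and write the component labels back to the table. A hedged end-to-end sketch; the TSV paths are placeholders and the column names are only those used by the functions above:

# Minimal sketch, assuming a segmentation table with the columns referenced above:
# label_id, anchor_x/y/z, n_pixels and the spatial-statistics column "distance_nn100".
import pandas as pd
from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg

table = pd.read_csv("sgn_seg_table.tsv", sep="\t")  # hypothetical table path
table = postprocess_sgn_seg(table, min_size=1000, min_edge_distance=30)

# Component label 0 marks background; other labels are ordered by descending component size.
print(table["component_labels"].value_counts().head())
table.to_csv("sgn_seg_table_postprocessed.tsv", sep="\t", index=False)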

scripts/prediction/expand_seg_table.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def main(
     neighbor_counts = [n[0] for n in neighbor_counts]
     tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts
 
-    tsv_table.to_csv(out_path, sep="\t")
+    tsv_table.to_csv(out_path, sep="\t", index=False)
 
 
 if __name__ == "__main__":
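
Writing with index=False keeps the dataframe index out of the TSV, so repeated load/extend/save cycles do not accumulate unnamed index columns. A small illustration (file names are arbitrary):

# Sketch: effect of index=False on the written columns.
import pandas as pd

df = pd.DataFrame({"label_id": [1, 2], "n_pixels": [1200, 800]})
df.to_csv("with_index.tsv", sep="\t")                    # writes a leading index column
df.to_csv("without_index.tsv", sep="\t", index=False)    # writes only the table columns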
