Commit 1ceafb8

Postprocessing cochlea segmentation using erosion
1 parent 4af55c2

File tree

2 files changed: +240 −2 lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 239 additions & 1 deletion
@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import os
 from concurrent import futures
 from typing import Callable, Tuple, Optional
 

@@ -8,8 +9,11 @@
 import pandas as pd
 
 from elf.io import open_file
-from scipy.spatial import distance
+from scipy.ndimage import binary_fill_holes
+from scipy.ndimage import distance_transform_edt
+from scipy.ndimage import label
 from scipy.sparse import csr_matrix
+from scipy.spatial import distance
 from scipy.spatial import cKDTree, ConvexHull
 from skimage import measure
 from sklearn.neighbors import NearestNeighbors
@@ -205,3 +209,237 @@ def filter_chunk(block_id):
     )
 
     return n_ids, n_ids_filtered
+
+
+# Postprocess segmentation by erosion using the above spatial statistics.
+# Currently implemented using downscaling and looking for connected components.
+# TODO: Change implementation to graph connected components.
+
+
+def erode_subset(
+    table: pd.DataFrame,
+    iterations: int = 1,
+    min_cells: Optional[int] = None,
+    threshold: int = 35,
+    keyword: str = "distance_nn100",
+) -> pd.DataFrame:
+    """Erode coordinates of a dataframe according to a keyword and a threshold.
+
+    Pass a copy of the dataframe as input if the original should not be edited.
+
+    Args:
+        table: Dataframe of segmentation table.
+        iterations: Number of steps for the erosion process.
+        min_cells: Minimal number of rows. The erosion stops once the subset falls below this number.
+        threshold: Upper threshold for removing elements according to the given keyword.
+        keyword: Keyword of the dataframe column used for the erosion.
+
+    Returns:
+        The dataframe containing the elements left after the erosion.
+    """
+    print("initial length", len(table))
+    n_neighbors = 100
+    for i in range(iterations):
+        table = table[table[keyword] < threshold]
+
+        # TODO: support other spatial statistics
+        distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)
+
+        if min_cells is not None and len(distance_avg) < min_cells:
+            print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion")
+            break
+
+        table.loc[:, "distance_nn" + str(n_neighbors)] = list(distance_avg)
+
+        print(f"{i}-th iteration, length of subset {len(table)}")
+
+    return table
+
+
+def downscaled_centroids(
+    table: pd.DataFrame,
+    scale_factor: int,
+    ref_dimensions: Optional[Tuple[float, float, float]] = None,
+    capped: bool = True,
+) -> np.typing.NDArray:
+    """Downscale the centroids of a segmentation table.
+
+    Args:
+        table: Dataframe of segmentation table.
+        scale_factor: Factor for downscaling the coordinates.
+        ref_dimensions: Reference dimensions for downscaling. Taken from the centroids if not supplied.
+        capped: Flag for capping the output array at 1 to create a binary mask.
+
+    Returns:
+        The downscaled array.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]
+
+    if ref_dimensions is None:
+        ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+    bounding_dimensions_scaled = tuple(round(b // scale_factor + 1) for b in ref_dimensions)
+    new_array = np.zeros(bounding_dimensions_scaled)
+
+    for c in centroids_scaled:
+        new_array[int(c[0]), int(c[1]), int(c[2])] += 1
+
+    array_downscaled = np.round(new_array).astype(int)
+
+    if capped:
+        array_downscaled[array_downscaled >= 1] = 1
+
+    return array_downscaled
+
+
+def coordinates_in_downscaled_blocks(
+    table: pd.DataFrame,
+    down_array: np.typing.NDArray,
+    scale_factor: float,
+    distance_component: int = 0,
+) -> list:
+    """Check if coordinates are within the downscaled array.
+
+    Args:
+        table: Dataframe of segmentation table.
+        down_array: Downscaled array.
+        scale_factor: Factor which was used for the downscaling.
+        distance_component: Distance in downscaled units up to which centroids next to downscaled blocks are included.
+
+    Returns:
+        A binary list representing whether the dataframe coordinates are within the array.
+    """
+    # fill holes in the down-sampled array
+    down_array[down_array > 0] = 1
+    down_array = binary_fill_holes(down_array).astype(np.uint8)
+
+    # check if the input coordinates are within the down-sampled blocks
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids_scaled = [np.floor(np.array([c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor])) for c in centroids]
+
+    # distance of each background voxel to the closest foreground block
+    distance_map = distance_transform_edt(down_array == 0)
+
+    centroids_binary = []
+    for c in centroids_scaled:
+        coord = (int(c[0]), int(c[1]), int(c[2]))
+        if down_array[coord] != 0:
+            centroids_binary.append(1)
+        elif distance_map[coord] <= distance_component:
+            centroids_binary.append(1)
+        else:
+            centroids_binary.append(0)
+
+    return centroids_binary
+
+
+def erode_sgn_seg(
+    table: pd.DataFrame,
+    keyword: str = "distance_nn100",
+    filter_small_components: Optional[int] = None,
+    scale_factor: float = 20,
+    threshold_erode: Optional[float] = None,
+) -> Tuple[np.typing.NDArray, np.typing.NDArray]:
+    """Erode the SGN segmentation.
+
+    Args:
+        table: Dataframe of segmentation table.
+        keyword: Keyword of the dataframe column used for the erosion.
+        filter_small_components: Filter out components smaller than this number of blocks after labeling.
+        scale_factor: Scaling for the downsampling.
+        threshold_erode: Threshold of the column value for the erosion step with spatial statistics.
+
+    Returns:
+        The labeled components of the downscaled, eroded coordinates.
+        The largest connected component of the labeled components.
+    """
+    ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+    print("initial length", len(table))
+    distance_nn = list(table[keyword])
+    distance_nn.sort()
+
+    if len(table) < 20000:
+        iterations = 1
+        min_cells = None
+        average_dist = int(distance_nn[int(len(table) * 0.8)])
+        threshold = threshold_erode if threshold_erode is not None else average_dist
+    else:
+        iterations = 15
+        min_cells = 20000
+        threshold = threshold_erode if threshold_erode is not None else 40
+
+    print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")
+
+    new_subset = erode_subset(table.copy(), iterations=iterations,
+                              threshold=threshold, min_cells=min_cells, keyword=keyword)
+    eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions)
+
+    # label connected components
+    labeled, num_features = label(eroded_arr)
+
+    # find the largest component
+    sizes = [(labeled == i).sum() for i in range(1, num_features + 1)]
+    largest_label = np.argmax(sizes) + 1
+
+    # extract only the largest component and fill its holes
+    largest_component = (labeled == largest_label).astype(np.uint8)
+    largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8)
+
+    # filter out small components
+    if filter_small_components is not None:
+        for size, feature in zip(sizes, range(1, num_features + 1)):
+            if size < filter_small_components:
+                labeled[labeled == feature] = 0
+
+    return labeled, largest_component_filtered
+
+
+def get_components(
+    table: pd.DataFrame,
+    labeled: np.typing.NDArray,
+    scale_factor: float,
+    distance_component: int = 0,
+) -> list:
+    """Index coordinates according to the labeled array.
+
+    Args:
+        table: Dataframe of segmentation table.
+        labeled: Array containing differently labeled components.
+        scale_factor: Scaling for the downsampling.
+        distance_component: Distance in downscaled units up to which centroids next to downscaled blocks are included.
+
+    Returns:
+        List of component labels.
+    """
+    unique_labels = list(np.unique(labeled))
+    component_labels = [0 for _ in range(len(table))]
+    for label_index, unique_label in enumerate(unique_labels):
+        if unique_label != 0:
+            label_arr = (labeled == unique_label).astype(np.uint8)
+            centroids_binary = coordinates_in_downscaled_blocks(table, label_arr,
+                                                                scale_factor, distance_component=distance_component)
+            for num, c in enumerate(centroids_binary):
+                if c != 0:
+                    component_labels[num] = label_index
+    return component_labels
+
+
+def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: float = 20) -> pd.DataFrame:
+    """Postprocess the SGN segmentation of the cochlea.
+
+    Args:
+        table: Dataframe of segmentation table.
+        scale_factor: Scaling for the downsampling.
+
+    Returns:
+        Dataframe with component labels.
+    """
+    labeled, largest_component = erode_sgn_seg(table, filter_small_components=10,
+                                               scale_factor=scale_factor, threshold_erode=None)
+
+    component_labels = get_components(table, labeled, scale_factor, distance_component=1)
+
+    table.loc[:, "component_labels"] = component_labels
+
+    return table
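The module comment above describes the approach as downscaling followed by a connected-components search. A toy illustration of that idea, not part of the commit (the centroid values and grid size are made up):

# Toy sketch of the downscale-and-label idea: bin centroids into coarse
# blocks, fill holes in the block mask, then label connected components.
import numpy as np
from scipy.ndimage import binary_fill_holes, label

centroids = np.array([[5, 5, 5], [8, 6, 5], [90, 90, 90]])  # two clusters, made up
scale_factor = 10

mask = np.zeros((10, 10, 10), dtype=np.uint8)
for c in centroids // scale_factor:
    mask[tuple(c)] = 1

mask = binary_fill_holes(mask).astype(np.uint8)
labeled, num_features = label(mask)
print(num_features)  # -> 2: the two clusters end up in separate components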
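And a minimal usage sketch of the new entry point, likewise not part of the commit. It assumes a segmentation table with "anchor_x"/"anchor_y"/"anchor_z" columns and a precomputed "distance_nn100" column, as produced earlier in the flamingo_tools pipeline; the path "seg_table.tsv" is hypothetical.

# Hypothetical usage of postprocess_sgn_seg; "seg_table.tsv" is a made-up path.
import pandas as pd

from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg

table = pd.read_csv("seg_table.tsv", sep="\t")

# Adds a "component_labels" column; 0 marks centroids assigned to no component.
table = postprocess_sgn_seg(table, scale_factor=20)
print(table["component_labels"].value_counts())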

scripts/prediction/expand_seg_table.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def main(
     neighbor_counts = [n[0] for n in neighbor_counts]
     tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts
 
-    tsv_table.to_csv(out_path, sep="\t")
+    tsv_table.to_csv(out_path, sep="\t", index=False)
 
 
 if __name__ == "__main__":
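For illustration, a small sketch (not part of the commit; the column names and output path are made up) of why index=False matters here: without it, pandas writes the dataframe's row index as an extra leading column, which shows up as a spurious column when the TSV is read back.

# Minimal sketch of the index=False fix.
import pandas as pd

tsv_table = pd.DataFrame({"label_id": [1, 2], "neighbors_in_radius15": [4, 7]})

# index=False keeps the row index out of the written file, so a later
# pd.read_csv("out.tsv", sep="\t") returns exactly the same columns.
tsv_table.to_csv("out.tsv", sep="\t", index=False)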
