Commit d9ae457

Update postprocessing code
1 parent 66670cd commit d9ae457

File tree

5 files changed: +156 additions, -82 deletions


flamingo_tools/file_utils.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]:
 
 # TODO: Update the any types:
 # The first should be the type of a zarr s3 store,
-def read_image_data(input_path: Union[str, Any], input_key: Optional[str]) -> np.array_like:
+def read_image_data(input_path: Union[str, Any], input_key: Optional[str]) -> np.typing.ArrayLike:
     """Read flamingo image data, stored in various formats.
 
     Args:
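
Usage example (not part of the commit): a minimal sketch of calling the function with its updated signature. The container path and key are made up; per the docstring and the TODO above, input_path may also be a zarr s3 store, and input_key selects the internal dataset where applicable.

    from flamingo_tools.file_utils import read_image_data

    # Read one dataset from a local zarr container (hypothetical path and key).
    data = read_image_data("/path/to/volume.ome.zarr", input_key="s0")
    print(data.shape)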
flamingo_tools/segmentation/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 from .unet_prediction import run_unet_prediction
-from .postprocessing import filter_isolated_objects
+from .postprocessing import filter_segmentation
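
With this rename, the function stays re-exported at package level, so downstream code such as scripts/prediction/postprocess_seg.py (below) imports it as:

    from flamingo_tools.segmentation import filter_segmentation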
flamingo_tools/segmentation/postprocessing.py

Lines changed: 149 additions & 74 deletions
@@ -1,97 +1,174 @@
-import numpy as np
-import vigra
 import multiprocessing as mp
 from concurrent import futures
+from typing import Callable, Tuple, Optional
 
-from skimage import measure
+import elf.parallel as parallel
+import numpy as np
+import nifty.tools as nt
+import pandas as pd
+import vigra
+
+from elf.io import open_file
 from scipy.spatial import distance
 from scipy.sparse import csr_matrix
-from tqdm import tqdm
+from scipy.spatial import cKDTree, ConvexHull
+from skimage import measure
 from sklearn.neighbors import NearestNeighbors
+from tqdm import tqdm
 
-import elf.parallel as parallel
-from elf.io import open_file
-import nifty.tools as nt
 
+#
+# Spatial statistics:
+# Three different spatial statistics implementations that
+# can be used as the basis of a filtering criterion.
+#
 
-def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True):
-    """Calculate average distance of n nearest neighbors.
 
-    :param DataFrame tsv_table:
-    :param int n_neighbors: Number of nearest neighbors
-    :param bool expand_table: Flag for expanding DataFrame
-    :returns: List of average distances
-    :rtype: list
-    """
-    centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
+def nearest_neighbor_distance(table: pd.DataFrame, n_neighbors: int = 10) -> np.ndarray:
+    """Compute the average distance to the n nearest neighbors.
+
+    Args:
+        table: The table with the centroid coordinates.
+        n_neighbors: The number of neighbors to take into account for the distance computation.
 
-    coordinates = np.array(centroids)
+    Returns:
+        The average distances to the n nearest neighbors.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids = np.array(centroids)
 
-    # nearest neighbor is always itself, so n_neighbors+=1
-    nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(coordinates)
-    distances, indices = nbrs.kneighbors(coordinates)
+    # Nearest neighbor is always itself, so n_neighbors+=1.
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(centroids)
+    distances, indices = nbrs.kneighbors(centroids)
 
     # Average distance to nearest neighbors
-    distance_avg = [sum(d) / len(d) for d in distances[:, 1:]]
+    distance_avg = np.array([sum(d) / len(d) for d in distances[:, 1:]])
+    return distance_avg
 
-    if expand_table:
-        tsv_table['distance_nn'+str(n_neighbors)] = distance_avg
 
-    return distance_avg
+def local_ripleys_k(table: pd.DataFrame, radius: float = 15, volume: Optional[float] = None) -> np.ndarray:
+    """Compute the local Ripley's K function for each point in a 2D or 3D point set.
 
+    Args:
+        table: The table with the centroid coordinates.
+        radius: The radius within which to count neighboring points.
+        volume: The area (2D) or volume (3D) of the study region. If None, it is estimated from the convex hull.
+
+    Returns:
+        An array containing the local K values for each point.
+    """
+    points = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    points = np.array(points)
+    n_points, dim = points.shape
+
+    if dim not in (2, 3):
+        raise ValueError("Points array must be of shape (n_points, 2) or (n_points, 3).")
+
+    # Estimate area/volume if not provided.
+    if volume is None:
+        hull = ConvexHull(points)
+        volume = hull.volume  # For 2D, 'volume' is area; for 3D, it's volume.
+
+    # Compute point density.
+    density = n_points / volume
+
+    # Build a KD-tree for efficient neighbor search.
+    tree = cKDTree(points)
+
+    # Count neighbors within the specified radius for each point.
+    counts = tree.query_ball_point(points, r=radius)
+    local_counts = np.array([len(c) - 1 for c in counts])  # Exclude the point itself.
+
+    # Normalize by density to get local K values.
+    local_k = local_counts / density
+    return local_k
 
-def filter_isolated_objects(
-    segmentation, output_path, tsv_table=None,
-    distance_threshold=15, neighbor_threshold=5, min_size=1000,
-    output_key="segmentation_postprocessed",
-):
-    """Postprocessing step to filter isolated objects from a segmentation.
 
-    Instance segmentations are filtered if they have fewer neighbors
-    than a given threshold in a given distance around them.
-    Additionally, size filtering is possible if a TSV file is supplied.
-
-    :param dataset segmentation: Dataset containing the segmentation
-    :param str out_path: Output path for postprocessed segmentation
-    :param str tsv_file: Optional TSV file containing segmentation parameters in MoBIE format
-    :param int distance_threshold: Distance in micrometer to check for neighbors
-    :param int neighbor_threshold: Minimal number of neighbors for filtering
-    :param int min_size: Minimal number of pixels for filtering small instances
-    :param str output_key: Output key for postprocessed segmentation
+def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray:
+    """Compute the number of neighbors within a given radius.
+
+    Args:
+        table: The table with the centroid coordinates.
+        radius: The radius within which to count neighboring points.
+
+    Returns:
+        An array containing the number of neighbors within the given radius.
     """
-    if tsv_table is not None:
-        n_pixels = tsv_table["n_pixels"].to_list()
-        label_ids = tsv_table["label_id"].to_list()
-        centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
-        n_ids = len(label_ids)
-
-        # filter out cells smaller than min_size
-        if min_size is not None:
-            min_size_label_ids = [l for (l, n) in zip(label_ids, n_pixels) if n <= min_size]
-            centroids = [c for (c, l) in zip(centroids, label_ids) if l not in min_size_label_ids]
-            label_ids = [int(lid) for lid in label_ids if lid not in min_size_label_ids]
-
-        coordinates = np.array(centroids)
-        label_ids = np.array(label_ids)
-
-    else:
-        segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
-        props = measure.regionprops(segmentation)
-        coordinates = np.array([prop.centroid for prop in props])
-        label_ids = np.unique(segmentation)[1:]
-
-    # Calculate pairwise distances and convert to a square matrix
-    dist_matrix = distance.pdist(coordinates)
+    points = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    points = np.array(points)
+
+    dist_matrix = distance.pdist(points)
     dist_matrix = distance.squareform(dist_matrix)
 
-    # Create sparse matrix of connections within the threshold distance
-    sparse_matrix = csr_matrix(dist_matrix < distance_threshold, dtype=int)
+    # Create sparse matrix of connections within the threshold distance.
+    sparse_matrix = csr_matrix(dist_matrix < radius, dtype=int)
 
-    # Sum each row to count neighbors
+    # Sum each row to count neighbors.
     neighbor_counts = sparse_matrix.sum(axis=1)
+    return np.array(neighbor_counts)
+
+
+#
+# Filtering function:
+# Filter the segmentation based on a spatial statistic from above.
+#
+
+
+def _compute_table(segmentation):
+    segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
+    props = measure.regionprops(segmentation)
+    # Note: regionprops already excludes the background label, so the properties align with the label ids below.
+    coordinates = np.array([prop.centroid for prop in props])
+    label_ids = np.unique(segmentation)[1:]
+    sizes = np.array([prop.area for prop in props])
+    table = pd.DataFrame({
+        "label_id": label_ids,
+        "n_pixels": sizes,
+        "anchor_x": coordinates[:, 2],
+        "anchor_y": coordinates[:, 1],
+        "anchor_z": coordinates[:, 0],
+    })
+    return table
+
+
+def filter_segmentation(
+    segmentation: np.typing.ArrayLike,
+    output_path: str,
+    spatial_statistics: Callable,
+    threshold: float,
+    min_size: int = 1000,
+    table: Optional[pd.DataFrame] = None,
+    output_key: str = "segmentation_postprocessed",
+) -> Tuple[int, int]:
+    """Postprocessing step to filter isolated objects from a segmentation.
 
-    filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze()
-    filter_ids = label_ids[filter_mask]
+    Instance segmentations are filtered based on spatial statistics and a threshold.
+    In addition, objects smaller than a given size are filtered out.
+
+    Args:
+        segmentation: Dataset containing the segmentation.
+        output_path: Output path for the postprocessed segmentation.
+        spatial_statistics: The function for computing the spatial statistic used as filtering criterion.
+        threshold: The threshold applied to the spatial statistic values; objects with values above it are kept.
+        min_size: Minimal number of pixels for filtering small instances.
+        table: Optional table with precomputed object measurements in MoBIE format.
+        output_key: Output key for the postprocessed segmentation.
+
+    Returns:
+        The number of objects before filtering.
+        The number of objects after filtering.
+    """
+    # Compute the table on the fly.
+    # NOTE: this currently doesn't work for large segmentations.
+    if table is None:
+        table = _compute_table(segmentation)
+    n_ids = len(table)
+
+    # First apply the size filter.
+    table = table[table.n_pixels > min_size]
+    stat_values = spatial_statistics(table)
+
+    keep_mask = np.array(stat_values > threshold).squeeze()
+    keep_ids = table.label_id.values[keep_mask]
 
     shape = segmentation.shape
     block_shape = (128, 128, 128)
@@ -100,7 +177,6 @@ def filter_isolated_objects(
     blocking = nt.blocking([0] * len(shape), shape, block_shape)
 
     output = open_file(output_path, mode="a")
-
     output_dataset = output.create_dataset(
         output_key, shape=shape, dtype=segmentation.dtype,
         chunks=chunks, compression="gzip"
@@ -112,17 +188,16 @@ def filter_chunk(block_id):
         block = blocking.getBlock(block_id)
         volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
         data = segmentation[volume_index]
-        data[np.isin(data, filter_ids)] = 0
+        # Zero out all objects that are not in the ids to keep.
+        data[~np.isin(data, keep_ids)] = 0
         output_dataset[volume_index] = data
 
     # Limit the number of cores for parallelization.
     n_threads = min(16, mp.cpu_count())
-
     with futures.ThreadPoolExecutor(n_threads) as filter_pool:
         list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))
 
     seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(
-        output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128)
+        output_dataset, start_label=1, keep_zeros=True, block_shape=block_shape
     )
 
-    return seg_filtered, n_ids, n_ids_filtered
+    return n_ids, n_ids_filtered
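
Usage example (not part of the commit): a sketch of wiring one of the spatial statistics into filter_segmentation. The file names are made up and the module path for neighbors_in_radius is inferred from the diff; functools.partial is one way to fix the radius beforehand, since filter_segmentation calls the statistic with the table as its only argument.

    from functools import partial

    import pandas as pd
    import zarr

    from flamingo_tools.segmentation import filter_segmentation
    from flamingo_tools.segmentation.postprocessing import neighbors_in_radius

    seg = zarr.open("segmentation.zarr", mode="r")["segmentation"]
    table = pd.read_csv("default.tsv", sep="\t")  # MoBIE-style table with label_id, n_pixels, anchor_*.

    # Keep objects with more than 5 neighbors within a 15 micrometer radius
    # (radius and threshold are made-up values).
    n_ids, n_ids_filtered = filter_segmentation(
        seg, "segmentation_postprocessed.zarr",
        spatial_statistics=partial(neighbors_in_radius, radius=15),
        threshold=5, min_size=1000, table=table,
    )
    print("Kept", n_ids_filtered, "of", n_ids, "objects")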

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 from tqdm import tqdm
 
 import flamingo_tools.s3_utils as s3_utils
-from flamingl_tools.file_utils import read_image_data
+from flamingo_tools.file_utils import read_image_data
 
 
 class SelectChannel(SimpleTransformationWrapper):
@@ -35,7 +35,7 @@ class SelectChannel(SimpleTransformationWrapper):
         volume: The array-like input dataset.
         channel: The channel that will be selected.
     """
-    def __init__(self, volume: np.array_like, channel: int):
+    def __init__(self, volume: np.typing.ArrayLike, channel: int):
         self.channel = channel
         super().__init__(volume, lambda x: x[self.channel], with_channels=True)
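
Usage example (not part of the commit): a sketch of wrapping a multi-channel volume with SelectChannel, assuming a zarr dataset with the channel as the leading axis; the path and key are made up, and the wrapper is assumed to support lazy slicing via SimpleTransformationWrapper.

    import zarr

    from flamingo_tools.segmentation.unet_prediction import SelectChannel

    volume = zarr.open("multichannel.ome.zarr", mode="r")["s0"]  # assumed shape (C, Z, Y, X)
    channel_0 = SelectChannel(volume, channel=0)

    # The channel selection is applied when data is read, so the wrapper
    # is intended to behave like a 3D array of the selected channel.
    block = channel_0[:64, :64, :64]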

scripts/prediction/postprocess_seg.py

Lines changed: 3 additions & 4 deletions
@@ -1,16 +1,15 @@
 import argparse
 import os
-import sys
 
 import pandas as pd
 import zarr
 
-sys.path.append("../..")
-
 import flamingo_tools.s3_utils as s3_utils
+from flamingo_tools.segmentation import filter_segmentation
+
 
+# TODO needs updates
 def main():
-    from flamingo_tools.segmentation import filter_isolated_objects
 
     parser = argparse.ArgumentParser(
         description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.")
