Skip to content

Commit bf58071

Browse files
committed
Updated script
1 parent 67193f5 commit bf58071

File tree

2 files changed

+113
-48
lines changed

2 files changed

+113
-48
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def erode_subset(
222222
Args:
223223
table: Dataframe of segmentation table.
224224
iterations: Number of steps for erosion process.
225-
min_cells: Minimal number of rows. The erosion is stopped before reaching this number.
225+
min_cells: Minimal number of rows. The erosion is stopped after falling below this limit.
226226
threshold: Upper threshold for removing elements according to the given keyword.
227227
keyword: Keyword of dataframe for erosion.
228228
@@ -259,7 +259,7 @@ def downscaled_centroids(
259259
table: Dataframe of segmentation table.
260260
scale_factor: Factor for downscaling coordinates.
261261
ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied.
262-
downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'
262+
downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'.
263263
264264
Returns:
265265
The downscaled array
@@ -326,7 +326,6 @@ def components_sgn(
326326
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
327327
labels = [int(i) for i in list(table["label_id"])]
328328

329-
print("initial length", len(table))
330329
distance_nn = list(table[keyword])
331330
distance_nn.sort()
332331

@@ -394,6 +393,7 @@ def components_sgn(
394393

395394
def label_components(
396395
table: pd.DataFrame,
396+
min_size: Optional[int] = 1000,
397397
threshold_erode: Optional[float] = None,
398398
min_component_length: Optional[int] = 50,
399399
min_edge_distance: Optional[float] = 30,
@@ -403,6 +403,7 @@ def label_components(
403403
404404
Args:
405405
table: Dataframe of segmentation table.
406+
min_size: Minimal number of pixels for filtering small instances.
406407
threshold_erode: Threshold of column value after erosion step with spatial statistics.
407408
min_component_length: Minimal length for filtering out connected components.
408409
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
@@ -411,9 +412,18 @@ def label_components(
411412
Returns:
412413
List of component labels, one for each point in the dataframe. 0 - background, then in descending order of size.
413414
"""
415+
416+
# First, apply the size filter.
417+
entries_filtered = table[table.n_pixels < min_size]
418+
table = table[table.n_pixels >= min_size]
419+
414420
components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
415421
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
416422

423+
# add size-filtered objects to have same initial length
424+
table = pd.concat([table, entries_filtered], ignore_index=True)
425+
table.sort_values("label_id")
426+
417427
length_components = [len(c) for c in components]
418428
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
419429

@@ -428,17 +438,30 @@ def label_components(
428438

429439
def postprocess_sgn_seg(
430440
table: pd.DataFrame,
441+
min_size: Optional[int] = 1000,
442+
threshold_erode: Optional[float] = None,
443+
min_component_length: Optional[int] = 50,
444+
min_edge_distance: Optional[float] = 30,
445+
iterations_erode: Optional[int] = None,
431446
) -> pd.DataFrame:
432447
"""Postprocessing SGN segmentation of cochlea.
433448
434449
Args:
435450
table: Dataframe of segmentation table.
451+
min_size: Minimal number of pixels for filtering small instances.
452+
threshold_erode: Threshold of column value after erosion step with spatial statistics.
453+
min_component_length: Minimal length for filtering out connected components.
454+
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
455+
iterations_erode: Number of iterations for erosion, normally determined automatically.
436456
437457
Returns:
438458
Dataframe with component labels.
439459
"""
440-
component_labels = label_components(table)
441460

442-
table.loc[:, "component_labels"] = component_labels
461+
comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
462+
min_component_length=min_component_length,
463+
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
464+
465+
table.loc[:, "component_labels"] = comp_labels
443466

444467
return table

scripts/prediction/postprocess_seg.py

Lines changed: 85 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import flamingo_tools.s3_utils as s3_utils
88
from flamingo_tools.segmentation import filter_segmentation
99
from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius
10+
from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg
1011

1112

1213
# TODO needs updates
@@ -15,18 +16,34 @@ def main():
1516
parser = argparse.ArgumentParser(
1617
description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.")
1718

18-
parser.add_argument("-o", "--output_folder", type=str, required=True)
19+
parser.add_argument("-o", "--output_folder", type=str, default=None)
1920

2021
parser.add_argument("-t", "--tsv", type=str, default=None,
2122
help="TSV-file in MoBIE format which contains information about segmentation.")
23+
parser.add_argument("--tsv_out", type=str, default=None,
24+
help="File path to save post-processed dataframe. Default: default.tsv")
25+
2226
parser.add_argument('-k', "--input_key", type=str, default="segmentation",
2327
help="The key / internal path of the segmentation.")
2428
parser.add_argument("--output_key", type=str, default="segmentation_postprocessed",
2529
help="The key / internal path of the output.")
2630
parser.add_argument('-r', "--resolution", type=float, default=0.38,
2731
help="Resolution of segmentation in micrometer.")
2832

29-
parser.add_argument("--s3_input", type=str, default=None, help="Input file path on S3 bucket.")
33+
# options for post-processing
34+
parser.add_argument("--min_size", type=int, default=1000,
35+
help="Minimal number of pixels for filtering small instances.")
36+
parser.add_argument("--threshold", type=float, default=None,
37+
help="Threshold for spatial statistics.")
38+
parser.add_argument("--min_component_length", type=int, default=50,
39+
help="Minimal length for filtering out connected components.")
40+
parser.add_argument("--min_edge_dist", type=float, default=30,
41+
help="Minimal distance in micrometer between points to create edges for connected components.")
42+
parser.add_argument("--iterations_erode", type=int, default=None,
43+
help="Number of iterations for erosion, normally determined automatically.")
44+
45+
# options for S3 bucket
46+
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
3047
parser.add_argument("--s3_credentials", type=str, default=None,
3148
help="Input file containing S3 credentials. "
3249
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
@@ -35,23 +52,42 @@ def main():
3552
parser.add_argument("--s3_service_endpoint", type=str, default=None,
3653
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
3754

38-
parser.add_argument("--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object")
39-
55+
# options for spatial statistics
4056
parser.add_argument("--n_neighbors", type=int, default=None,
4157
help="Value for calculating distance to 'n' nearest neighbors.")
42-
4358
parser.add_argument("--local_ripley_radius", type=int, default=None,
4459
help="Value for radius for calculating local Ripley's K function.")
45-
4660
parser.add_argument("--r_neighbors", type=int, default=None,
4761
help="Value for radius for calculating number of neighbors in range.")
4862

4963
args = parser.parse_args()
5064

65+
if args.output_folder is None and args.tsv is None:
66+
raise ValueError("Either supply an output folder containing 'segmentation.zarr' or a TSV-file in MoBIE format.")
67+
68+
# check output folder
69+
if args.output_folder is not None:
70+
seg_path = os.path.join(args.output_folder, "segmentation.zarr")
71+
if args.s3:
72+
s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name,
73+
service_endpoint=args.s3_service_endpoint,
74+
credential_file=args.s3_credentials)
75+
with zarr.open(s3_path, mode="r") as f:
76+
segmentation = f[args.input_key]
77+
else:
78+
with zarr.open(seg_path, mode="r") as f:
79+
segmentation = f[args.input_key]
80+
else:
81+
seg_path = None
82+
83+
# check input for spatial statistics
5184
postprocess_functions = [nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius]
5285
function_keywords = ["n_neighbors", "radius", "radius"]
5386
postprocess_options = [args.n_neighbors, args.local_ripley_radius, args.r_neighbors]
54-
default_thresholds = [15, 20, 20]
87+
default_thresholds = [args.threshold for _ in postprocess_functions]
88+
89+
if seg_path is not None and args.threshold is None:
90+
default_thresholds = [15, 20, 20]
5591

5692
def create_spatial_statistics_dict(functions, keyword, options, threshold):
5793
spatial_statistics_dict = []
@@ -62,52 +98,58 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold):
6298

6399
spatial_statistics_dict = create_spatial_statistics_dict(postprocess_functions, postprocess_options,
64100
function_keywords, default_thresholds)
65-
66-
if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0:
67-
raise ValueError("Choose a postprocess function from 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.")
68-
elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1:
69-
raise ValueError("The script only supports a single postprocess function.")
70-
else:
71-
for d in spatial_statistics_dict:
72-
if d["argument"] is not None:
73-
spatial_statistics = d["function"]
74-
spatial_statistics_kwargs = {d["keyword"]: d["argument"]}
75-
threshold = d["threshold"]
76-
77-
seg_path = os.path.join(args.output_folder, "segmentation.zarr")
78-
101+
if seg_path is not None:
102+
if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0:
103+
raise ValueError("Choose a postprocess function: 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.")
104+
elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1:
105+
raise ValueError("The script only supports a single postprocess function.")
106+
else:
107+
for d in spatial_statistics_dict:
108+
if d["argument"] is not None:
109+
spatial_statistics = d["function"]
110+
spatial_statistics_kwargs = {d["keyword"]: d["argument"]}
111+
threshold = d["threshold"]
112+
113+
# check TSV-file containing data in MoBIE format
79114
tsv_table = None
80-
81-
if args.s3_input is not None:
82-
s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name,
83-
service_endpoint=args.s3_service_endpoint,
84-
credential_file=args.s3_credentials)
85-
with zarr.open(s3_path, mode="r") as f:
86-
segmentation = f[args.input_key]
87-
88-
if args.tsv is not None:
115+
if args.tsv is not None:
116+
if args.s3:
89117
tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name,
90118
service_endpoint=args.s3_service_endpoint,
91119
credential_file=args.s3_credentials)
92120
with fs.open(tsv_path, 'r') as f:
93121
tsv_table = pd.read_csv(f, sep="\t")
94-
95-
else:
96-
with zarr.open(seg_path, mode="r") as f:
97-
segmentation = f[args.input_key]
98-
99-
if args.tsv is not None:
122+
else:
100123
with open(args.tsv, 'r') as f:
101124
tsv_table = pd.read_csv(f, sep="\t")
102125

103-
n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path,
104-
spatial_statistics=spatial_statistics,
105-
threshold=threshold,
106-
min_size=args.min_size, table=tsv_table,
107-
resolution=args.resolution,
108-
output_key=args.output_key, **spatial_statistics_kwargs)
126+
if seg_path is None:
127+
post_table = postprocess_sgn_seg(
128+
tsv_table.copy(), min_size=args.min_size, threshold_erode=args.threshold,
129+
min_component_length=args.min_component_length, min_edge_distance=args.min_edge_dist,
130+
iterations_erode=args.iterations_erode,
131+
)
132+
133+
if args.tsv_out is None:
134+
out_path = "default.tsv"
135+
else:
136+
out_path = args.tsv_out
137+
post_table.to_csv(out_path, sep="\t", index=False)
138+
139+
n_pre = len(tsv_table)
140+
n_post = len(post_table["component_labels"][post_table["component_labels"] == 1])
109141

110-
print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}")
142+
print(f"Number of pre-filtered objects: {n_pre}\nNumber of objects in largest component: {n_post}")
143+
144+
else:
145+
n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path,
146+
spatial_statistics=spatial_statistics,
147+
threshold=threshold,
148+
min_size=args.min_size, table=tsv_table,
149+
resolution=args.resolution,
150+
output_key=args.output_key, **spatial_statistics_kwargs)
151+
152+
print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}")
111153

112154

113155
if __name__ == "__main__":

0 commit comments

Comments
 (0)