77import flamingo_tools .s3_utils as s3_utils
88from flamingo_tools .segmentation import filter_segmentation
99from flamingo_tools .segmentation .postprocessing import nearest_neighbor_distance , local_ripleys_k , neighbors_in_radius
10+ from flamingo_tools .segmentation .postprocessing import postprocess_sgn_seg
1011
1112
1213# TODO needs updates
@@ -15,18 +16,34 @@ def main():
1516 parser = argparse .ArgumentParser (
1617 description = "Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket." )
1718
18- parser .add_argument ("-o" , "--output_folder" , type = str , required = True )
19+ parser .add_argument ("-o" , "--output_folder" , type = str , default = None )
1920
2021 parser .add_argument ("-t" , "--tsv" , type = str , default = None ,
2122 help = "TSV-file in MoBIE format which contains information about segmentation." )
23+ parser .add_argument ("--tsv_out" , type = str , default = None ,
24+ help = "File path to save post-processed dataframe. Default: default.tsv" )
25+
2226 parser .add_argument ('-k' , "--input_key" , type = str , default = "segmentation" ,
2327 help = "The key / internal path of the segmentation." )
2428 parser .add_argument ("--output_key" , type = str , default = "segmentation_postprocessed" ,
2529 help = "The key / internal path of the output." )
2630 parser .add_argument ('-r' , "--resolution" , type = float , default = 0.38 ,
2731 help = "Resolution of segmentation in micrometer." )
2832
29- parser .add_argument ("--s3_input" , type = str , default = None , help = "Input file path on S3 bucket." )
33+ # options for post-processing
34+ parser .add_argument ("--min_size" , type = int , default = 1000 ,
35+ help = "Minimal number of pixels for filtering small instances." )
36+ parser .add_argument ("--threshold" , type = float , default = None ,
37+ help = "Threshold for spatial statistics." )
38+ parser .add_argument ("--min_component_length" , type = int , default = 50 ,
39+ help = "Minimal length for filtering out connected components." )
40+ parser .add_argument ("--min_edge_dist" , type = float , default = 30 ,
41+ help = "Minimal distance in micrometer between points to create edges for connected components." )
42+ parser .add_argument ("--iterations_erode" , type = int , default = None ,
43+ help = "Number of iterations for erosion, normally determined automatically." )
44+
45+ # options for S3 bucket
46+ parser .add_argument ("--s3" , action = "store_true" , help = "Flag for using S3 bucket." )
3047 parser .add_argument ("--s3_credentials" , type = str , default = None ,
3148 help = "Input file containing S3 credentials. "
3249 "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported." )
@@ -35,23 +52,42 @@ def main():
3552 parser .add_argument ("--s3_service_endpoint" , type = str , default = None ,
3653 help = "S3 service endpoint. Optional if SERVICE_ENDPOINT was exported." )
3754
38- parser .add_argument ("--min_size" , type = int , default = 1000 , help = "Minimal number of voxel size for counting object" )
39-
55+ # options for spatial statistics
4056 parser .add_argument ("--n_neighbors" , type = int , default = None ,
4157 help = "Value for calculating distance to 'n' nearest neighbors." )
42-
4358 parser .add_argument ("--local_ripley_radius" , type = int , default = None ,
4459 help = "Value for radius for calculating local Ripley's K function." )
45-
4660 parser .add_argument ("--r_neighbors" , type = int , default = None ,
4761 help = "Value for radius for calculating number of neighbors in range." )
4862
4963 args = parser .parse_args ()
5064
65+ if args .output_folder is None and args .tsv is None :
66+ raise ValueError ("Either supply an output folder containing 'segmentation.zarr' or a TSV-file in MoBIE format." )
67+
68+ # check output folder
69+ if args .output_folder is not None :
70+ seg_path = os .path .join (args .output_folder , "segmentation.zarr" )
71+ if args .s3 :
72+ s3_path , fs = s3_utils .get_s3_path (args .s3_input , bucket_name = args .s3_bucket_name ,
73+ service_endpoint = args .s3_service_endpoint ,
74+ credential_file = args .s3_credentials )
75+ with zarr .open (s3_path , mode = "r" ) as f :
76+ segmentation = f [args .input_key ]
77+ else :
78+ with zarr .open (seg_path , mode = "r" ) as f :
79+ segmentation = f [args .input_key ]
80+ else :
81+ seg_path = None
82+
83+ # check input for spatial statistics
5184 postprocess_functions = [nearest_neighbor_distance , local_ripleys_k , neighbors_in_radius ]
5285 function_keywords = ["n_neighbors" , "radius" , "radius" ]
5386 postprocess_options = [args .n_neighbors , args .local_ripley_radius , args .r_neighbors ]
54- default_thresholds = [15 , 20 , 20 ]
87+ default_thresholds = [args .threshold for _ in postprocess_functions ]
88+
89+ if seg_path is not None and args .threshold is None :
90+ default_thresholds = [15 , 20 , 20 ]
5591
5692 def create_spatial_statistics_dict (functions , keyword , options , threshold ):
5793 spatial_statistics_dict = []
@@ -62,52 +98,58 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold):
6298
6399 spatial_statistics_dict = create_spatial_statistics_dict (postprocess_functions , postprocess_options ,
64100 function_keywords , default_thresholds )
65-
66- if sum (x ["argument" ] is not None for x in spatial_statistics_dict ) == 0 :
67- raise ValueError ("Choose a postprocess function from 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'." )
68- elif sum (x ["argument" ] is not None for x in spatial_statistics_dict ) > 1 :
69- raise ValueError ("The script only supports a single postprocess function." )
70- else :
71- for d in spatial_statistics_dict :
72- if d ["argument" ] is not None :
73- spatial_statistics = d ["function" ]
74- spatial_statistics_kwargs = {d ["keyword" ]: d ["argument" ]}
75- threshold = d ["threshold" ]
76-
77- seg_path = os .path .join (args .output_folder , "segmentation.zarr" )
78-
101+ if seg_path is not None :
102+ if sum (x ["argument" ] is not None for x in spatial_statistics_dict ) == 0 :
103+ raise ValueError ("Choose a postprocess function: 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'." )
104+ elif sum (x ["argument" ] is not None for x in spatial_statistics_dict ) > 1 :
105+ raise ValueError ("The script only supports a single postprocess function." )
106+ else :
107+ for d in spatial_statistics_dict :
108+ if d ["argument" ] is not None :
109+ spatial_statistics = d ["function" ]
110+ spatial_statistics_kwargs = {d ["keyword" ]: d ["argument" ]}
111+ threshold = d ["threshold" ]
112+
113+ # check TSV-file containing data in MoBIE format
79114 tsv_table = None
80-
81- if args .s3_input is not None :
82- s3_path , fs = s3_utils .get_s3_path (args .s3_input , bucket_name = args .s3_bucket_name ,
83- service_endpoint = args .s3_service_endpoint ,
84- credential_file = args .s3_credentials )
85- with zarr .open (s3_path , mode = "r" ) as f :
86- segmentation = f [args .input_key ]
87-
88- if args .tsv is not None :
115+ if args .tsv is not None :
116+ if args .s3 :
89117 tsv_path , fs = s3_utils .get_s3_path (args .tsv , bucket_name = args .s3_bucket_name ,
90118 service_endpoint = args .s3_service_endpoint ,
91119 credential_file = args .s3_credentials )
92120 with fs .open (tsv_path , 'r' ) as f :
93121 tsv_table = pd .read_csv (f , sep = "\t " )
94-
95- else :
96- with zarr .open (seg_path , mode = "r" ) as f :
97- segmentation = f [args .input_key ]
98-
99- if args .tsv is not None :
122+ else :
100123 with open (args .tsv , 'r' ) as f :
101124 tsv_table = pd .read_csv (f , sep = "\t " )
102125
103- n_pre , n_post = filter_segmentation (segmentation , output_path = seg_path ,
104- spatial_statistics = spatial_statistics ,
105- threshold = threshold ,
106- min_size = args .min_size , table = tsv_table ,
107- resolution = args .resolution ,
108- output_key = args .output_key , ** spatial_statistics_kwargs )
126+ if seg_path is None :
127+ post_table = postprocess_sgn_seg (
128+ tsv_table .copy (), min_size = args .min_size , threshold_erode = args .threshold ,
129+ min_component_length = args .min_component_length , min_edge_distance = args .min_edge_dist ,
130+ iterations_erode = args .iterations_erode ,
131+ )
132+
133+ if args .tsv_out is None :
134+ out_path = "default.tsv"
135+ else :
136+ out_path = args .tsv_out
137+ post_table .to_csv (out_path , sep = "\t " , index = False )
138+
139+ n_pre = len (tsv_table )
140+ n_post = len (post_table ["component_labels" ][post_table ["component_labels" ] == 1 ])
109141
110- print (f"Number of pre-filtered objects: { n_pre } \n Number of post-filtered objects: { n_post } " )
142+ print (f"Number of pre-filtered objects: { n_pre } \n Number of objects in largest component: { n_post } " )
143+
144+ else :
145+ n_pre , n_post = filter_segmentation (segmentation , output_path = seg_path ,
146+ spatial_statistics = spatial_statistics ,
147+ threshold = threshold ,
148+ min_size = args .min_size , table = tsv_table ,
149+ resolution = args .resolution ,
150+ output_key = args .output_key , ** spatial_statistics_kwargs )
151+
152+ print (f"Number of pre-filtered objects: { n_pre } \n Number of post-filtered objects: { n_post } " )
111153
112154
113155if __name__ == "__main__" :
0 commit comments