@@ -37,6 +37,13 @@ def _parse_annotation_path(annotation_path):
3737 return cochlea , slice_id
3838
3939
40+ def _get_table (fs , cochlea , seg_name ):
41+ internal_path = os .path .join (BUCKET_NAME , cochlea , "tables" , seg_name , "default.tsv" )
42+ with fs .open (internal_path , "r" ) as f :
43+ table = pd .read_csv (f , sep = "\t " )
44+ return table
45+
46+
4047def fetch_data_for_evaluation (
4148 annotation_path : str ,
4249 cache_path : Optional [str ] = None ,
@@ -45,6 +52,7 @@ def fetch_data_for_evaluation(
4552 components_for_postprocessing : Optional [List [int ]] = None ,
4653 cochlea : Optional [str ] = None ,
4754 extra_data : Optional [str ] = None ,
55+ exclude_zero_synapse_count : bool = False ,
4856) -> Tuple [np .ndarray , pd .DataFrame ]:
4957 """Fetch segmentation from S3 matching the annotation path for evaluation.
5058
@@ -57,6 +65,8 @@ def fetch_data_for_evaluation(
5765 Choose [1] for the default componentn containing the helix.
5866 cochlea: Optional name of the cochlea.
5967 extra_data: Extra data to fetch.
68+ exclude_zero_synapse_count: Exclude cells that have zero synapses mapped.
69+ This is relevant for the IHC evaluation.
6070
6171 Returns:
6272 The segmentation downloaded from the S3 bucket.
@@ -96,20 +106,28 @@ def fetch_data_for_evaluation(
96106 with zarr .open (s3_store , mode = "r" ) as f :
97107 segmentation = f [input_key ][roi ]
98108
109+ table = None
99110 if components_for_postprocessing is not None :
100111 # Filter the IDs so that only the ones part of 'components_for_postprocessing_remain'.
101-
102- # First, we download the MoBIE table for this segmentation.
103- internal_path = os .path .join (BUCKET_NAME , cochlea , "tables" , seg_name , "default.tsv" )
104- with fs .open (internal_path , "r" ) as f :
105- table = pd .read_csv (f , sep = "\t " )
112+ table = _get_table (fs , cochlea , seg_name )
106113
107114 # Then we get the ids for the components and us them to filter the segmentation.
108115 component_mask = np .isin (table .component_labels .values , components_for_postprocessing )
109116 keep_label_ids = table .label_id .values [component_mask ].astype ("int64" )
110117 filter_mask = ~ np .isin (segmentation , keep_label_ids )
111118 segmentation [filter_mask ] = 0
112119
120+ # We also filter the table accordingly.
121+ table = table [table .label_id .isin (keep_label_ids )]
122+
123+ if exclude_zero_synapse_count :
124+ if table is None :
125+ table = _get_table (fs , cochlea , seg_name )
126+
127+ keep_label_ids = table .label_id [table .syn_per_IHC > 0 ].astype ("int64" )
128+ filter_mask = ~ np .isin (segmentation , keep_label_ids )
129+ segmentation [filter_mask ] = 0
130+
113131 segmentation , _ , _ = relabel_sequential (segmentation )
114132
115133 # Cache it if required.
0 commit comments