77import pandas as pd
88import zarr
99
10+ from scipy .ndimage import distance_transform_edt
11+ from scipy .optimize import linear_sum_assignment
12+ from skimage .measure import regionprops_table
13+ from skimage .segmentation import relabel_sequential
14+ from tqdm import tqdm
15+
1016from .s3_utils import get_s3_path , BUCKET_NAME , SERVICE_ENDPOINT
1117
1218
@@ -21,16 +27,26 @@ def _normalize_cochlea_name(name):
2127 return f"{ prefix } _{ number :06d} _{ postfix } "
2228
2329
24- # For a less naive annotation we may need to also fetch +- a few slices,
25- # so that we have a bit of tolerance with the distance based matching.
30+ # TODO enable table component filtering with MoBIE table
2631def fetch_data_for_evaluation (
2732 annotation_path : str ,
2833 cache_path : Optional [str ] = None ,
2934 seg_name : str = "SGN" ,
35+ z_extent : int = 0 ,
3036) -> Tuple [np .ndarray , pd .DataFrame ]:
3137 """
3238 """
39+ # Load the annotations and normalize them for the given z-extent.
3340 annotations = pd .read_csv (annotation_path )
41+ annotations = annotations .drop (columns = "index" )
42+ if z_extent == 0 : # If we don't have a z-extent then we just drop the first axis and rename the other two.
43+ annotations = annotations .drop (columns = "axis-0" )
44+ annotations = annotations .rename (columns = {"axis-1" : "axis-0" , "axis-2" : "axis-1" })
45+ else : # Otherwise we have to center the first axis.
46+ # TODO
47+ raise NotImplementedError
48+
49+ # Load the segmentaiton from cache path if it is given and if it is already cached.
3450 if cache_path is not None and os .path .exists (cache_path ):
3551 segmentation = imageio .imread (cache_path )
3652 return segmentation , annotations
@@ -45,10 +61,17 @@ def fetch_data_for_evaluation(
4561 internal_path = os .path .join (cochlea , "images" , "ome-zarr" , f"{ seg_name } .ome.zarr" )
4662 s3_store , fs = get_s3_path (internal_path , bucket_name = BUCKET_NAME , service_endpoint = SERVICE_ENDPOINT )
4763
48- # Download the segmentation for this slice.
64+ # Compute the roi for the given z-extent.
65+ if z_extent == 0 :
66+ roi = slice_id
67+ else :
68+ roi = slice (slice_id - z_extent , slice_id + z_extent )
69+
70+ # Download the segmentation for this slice and the given z-extent.
4971 input_key = "s0"
5072 with zarr .open (s3_store , mode = "r" ) as f :
51- segmentation = f [input_key ][slice_id ]
73+ segmentation = f [input_key ][roi ]
74+ segmentation , _ , _ = relabel_sequential (segmentation )
5275
5376 # Cache it if required.
5477 if cache_path is not None :
@@ -57,7 +80,61 @@ def fetch_data_for_evaluation(
5780 return segmentation , annotations
5881
5982
60- def evaluate_annotated_slice (
def compute_matches_for_annotated_slice(
    segmentation: np.typing.ArrayLike,
    annotations: pd.DataFrame,
    matching_tolerance: float = 0.0,
) -> Dict[str, np.ndarray]:
    """Computes the ids of matches and non-matches for an annotated validation slice.

    Computes true positive ids (for objects and annotations), false positive ids and false negative ids
    by solving a linear cost assignment of distances between objects and annotations.

    Args:
        segmentation: The segmentation for this slide. We assume that it is relabeled consecutively.
        annotations: The annotations, marking cell centers.
        matching_tolerance: The maximum distance for matching an annotation to a segmented object.

    Returns:
        A dictionary with keys 'tp_objects', 'tp_annotations', 'fp' and 'fn', mapping to the respective ids.
    """
    assert segmentation.ndim in (2, 3)
    segmentation_ids = np.unique(segmentation)[1:]
    n_objects, n_annotations = len(segmentation_ids), len(annotations)

    # In order to get the full distance matrix, we compute the distance to all objects for each annotation.
    # This is not very efficient, but it's the most straight-forward and most rigorous approach.
    scores = np.zeros((n_objects, n_annotations), dtype="float")
    coordinates = ["axis-0", "axis-1"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
    # NOTE: we enumerate instead of using the index returned by iterrows, because 'scores[:, i]' and the
    # 'np.arange(n_annotations)' below require positional indices; a non-RangeIndex would misalign them.
    for i, (_, row) in enumerate(tqdm(annotations.iterrows(), total=n_annotations, desc="Compute pairwise distances")):
        # Round the annotation to the nearest pixel and clip it to the image bounds, so that
        # annotations that fall marginally outside the slice don't raise an IndexError.
        coordinate = tuple(
            min(max(int(np.round(row[coord])), 0), extent - 1)
            for coord, extent in zip(coordinates, segmentation.shape)
        )
        distance_input = np.ones(segmentation.shape, dtype="bool")
        distance_input[coordinate] = False
        # We only need the distances here; computing the nearest-point indices as well
        # (return_indices=True) would just waste time and memory.
        distances = distance_transform_edt(distance_input)

        # The minimum intensity over the distance field per object is the distance of the
        # annotation to that object (zero if the annotation lies inside the object).
        props = regionprops_table(segmentation, intensity_image=distances, properties=("label", "min_intensity"))
        object_distances = props["min_intensity"]
        assert len(object_distances) == scores.shape[0]
        scores[:, i] = object_distances

    # Find the assignment of points to objects.
    # These correspond to the TP ids in the point / object annotations.
    tp_ids_objects, tp_ids_annotations = linear_sum_assignment(scores)
    match_ok = scores[tp_ids_objects, tp_ids_annotations] <= matching_tolerance
    tp_ids_objects, tp_ids_annotations = tp_ids_objects[match_ok], tp_ids_annotations[match_ok]
    # Translate positional object indices back to segmentation label ids.
    tp_ids_objects = segmentation_ids[tp_ids_objects]
    assert len(tp_ids_objects) == len(tp_ids_annotations)

    # Find the false positives: objects that are not part of the matches.
    fp_ids = np.setdiff1d(segmentation_ids, tp_ids_objects)

    # Find the false negatives: annotations that are not part of the matches.
    fn_ids = np.setdiff1d(np.arange(n_annotations), tp_ids_annotations)

    return {"tp_objects": tp_ids_objects, "tp_annotations": tp_ids_annotations, "fp": fp_ids, "fn": fn_ids}
135+
136+
def compute_scores_for_annotated_slice(
    segmentation: np.typing.ArrayLike,
    annotations: pd.DataFrame,
    matching_tolerance: float = 0.0,
) -> Dict[str, int]:
    """Evaluates a segmentation against point annotations for a validation slice.

    Computes true positives, false positives and false negatives for scoring.

    Args:
        segmentation: The segmentation for this slide. We assume that it is relabeled consecutively.
        annotations: The annotations, marking cell centers.
        matching_tolerance: The maximum distance for matching an annotation to a segmented object.

    Returns:
        A dictionary with keys 'tp', 'fp' and 'fn', mapping to the respective counts.
    """
    # Delegate the distance-based matching and reduce the matched id sets to counts.
    matches = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance)
    return {
        "tp": len(matches["tp_objects"]),
        "fp": len(matches["fp"]),
        "fn": len(matches["fn"]),
    }
def for_visualization(segmentation, annotations, matches):
    """Prepares layers and layer properties for visualizing the matching result.

    The returned properties (color maps, border colors) follow a napari-style layer
    API — presumably consumed by napari Labels / Points layers; confirm with the caller.

    Args:
        segmentation: The segmentation for the annotated slice.
        annotations: The point annotations, marking cell centers.
        matches: The match result, as computed by `compute_matches_for_annotated_slice`.

    Returns:
        A tuple (seg_vis, point_vis, seg_props, point_props) with the label visualization,
        the point annotations, and the display properties for both layers.
    """
    green_red = ["#00FF00", "#FF0000"]

    # Mark matched objects (true positives) with 1 and unmatched objects (false positives) with 2.
    seg_vis = np.zeros_like(segmentation)
    tps, fps = matches["tp_objects"], matches["fp"]
    seg_vis[np.isin(segmentation, tps)] = 1
    seg_vis[np.isin(segmentation, fps)] = 2

    # TODO red / green colormap
    seg_props = dict(color={1: green_red[0], 2: green_red[1]})

    point_vis = annotations.copy()
    # Use a set for O(1) membership tests; testing `aid in ndarray` scans the array per annotation.
    tp_annotations = {int(aid) for aid in matches["tp_annotations"]}
    point_props = dict(
        properties={"match": [0 if aid in tp_annotations else 1 for aid in range(len(annotations))]},
        border_color="match",
        border_color_cycle=green_red,
    )

    return seg_vis, point_vis, seg_props, point_props
0 commit comments