@@ -43,6 +43,8 @@ def fetch_data_for_evaluation(
     seg_name: str = "SGN_v2",
     z_extent: int = 0,
     components_for_postprocessing: Optional[List[int]] = None,
+    cochlea: Optional[str] = None,
+    extra_data: Optional[str] = None,
 ) -> Tuple[np.ndarray, pd.DataFrame]:
     """Fetch segmentation from S3 matching the annotation path for evaluation.

@@ -53,28 +55,31 @@ def fetch_data_for_evaluation(
         z_extent: Additional z-slices to load from the segmentation.
         components_for_postprocessing: The component ids to restrict the segmentation to.
             Choose [1] for the default component containing the helix.
+        cochlea: Optional name of the cochlea. If not given, it is parsed from the annotation path.
+        extra_data: Optional name of additional image data to fetch for the same ROI; if given, it is returned as a third value.

     Returns:
         The segmentation downloaded from the S3 bucket.
         The annotations loaded with pandas, matching the segmentation.
     """
     # Load the annotations and normalize them for the given z-extent.
     annotations = pd.read_csv(annotation_path)
-    annotations = annotations.drop(columns="index")
+    if "index" in annotations.columns:
+        annotations = annotations.drop(columns="index")
     if z_extent == 0:  # If we don't have a z-extent then we just drop the first axis and rename the other two.
         annotations = annotations.drop(columns="axis-0")
         annotations = annotations.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
-    else:  # Otherwise we have to center the first axis.
-        # TODO
-        raise NotImplementedError

     # Load the segmentation from the cache path if it is given and already cached.
     if cache_path is not None and os.path.exists(cache_path):
         segmentation = imageio.imread(cache_path)
         return segmentation, annotations

     # Parse the slice ID from the annotation path, and the cochlea too if it was not passed.
-    cochlea, slice_id = _parse_annotation_path(annotation_path)
+    if cochlea is None:
+        cochlea, slice_id = _parse_annotation_path(annotation_path)
+    else:
+        _, slice_id = _parse_annotation_path(annotation_path)

     # Open the S3 connection, get the path to the SGN segmentation in S3.
     internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr")
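
The segmentation is then read from this internal path via `get_s3_path` and `zarr` (see the hunks below). For orientation, here is a minimal sketch of how such a zarr store over S3 can be built with `s3fs`; this is an illustrative stand-in assuming anonymous read access, not the repository's actual `get_s3_path` implementation:

    import s3fs
    import zarr

    def open_s3_zarr(internal_path, bucket_name, service_endpoint):
        # Illustrative only: build a key-value store backed by the S3 bucket.
        fs = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": service_endpoint})
        store = s3fs.S3Map(root=f"{bucket_name}/{internal_path}", s3=fs)
        return store, fs

    # The scale-0 dataset of the OME-Zarr image can then be read lazily:
    # store, fs = open_s3_zarr(internal_path, BUCKET_NAME, SERVICE_ENDPOINT)
    # data = zarr.open(store, mode="r")["s0"][...]
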
@@ -111,6 +116,14 @@ def fetch_data_for_evaluation(
     if cache_path is not None:
         imageio.imwrite(cache_path, segmentation, compression="zlib")

+    if extra_data is not None:  # Fetch the extra image data for the same ROI and return it as a third value.
+        internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{extra_data}.ome.zarr")
+        s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
+        input_key = "s0"
+        with zarr.open(s3_store, mode="r") as f:
+            extra_im_data = f[input_key][roi]
+        return segmentation, annotations, extra_im_data
+
     return segmentation, annotations


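Usage sketch for the extended signature; the annotation path, cochlea name, and extra-data name below are hypothetical placeholders:

    # Hypothetical annotation path, for illustration only.
    annotation_path = "annotations/M_LR_000226_R/slice42.csv"

    # As before: fetch segmentation and annotations, parsing the cochlea from the path.
    segmentation, annotations = fetch_data_for_evaluation(annotation_path)

    # Pass the cochlea explicitly and also fetch an extra image channel,
    # which is returned as a third value.
    segmentation, annotations, extra_im = fetch_data_for_evaluation(
        annotation_path, cochlea="M_LR_000226_R", extra_data="PV",
    )
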
@@ -347,6 +360,62 @@ def union(a, b):
     return consensus_df, unmatched_df


+def match_detections(
+    detections: np.ndarray,
+    annotations: np.ndarray,
+    max_dist: float
+):
+    """One-to-one matching between 3-D detections and ground-truth points.
+
+    Args:
+        detections: N x 3 candidate detections.
+        annotations: M x 3 ground-truth annotations for the reference points.
+        max_dist: Maximum Euclidean distance allowed for a match.
+
+    Returns:
+        Indices in `detections` that were matched (true positives).
+        Indices in `annotations` that were matched (true positives).
+        Unmatched detection indices (false positives).
+        Unmatched annotation indices (false negatives).
+    """
+    det = np.asarray(detections, dtype=float)
+    ann = np.asarray(annotations, dtype=float)
+    N, M = len(det), len(ann)
+
+    # Trivial corner cases --------------------------------------------------------
+    if N == 0:
+        return np.empty(0, int), np.empty(0, int), np.empty(0, int), np.arange(M)
+    if M == 0:
+        return np.empty(0, int), np.empty(0, int), np.arange(N), np.empty(0, int)
+
+    # 1. Build a sparse, radius-filtered distance matrix ---------------------------
+    tree_det = cKDTree(det)
+    tree_ann = cKDTree(ann)
+    coo = tree_det.sparse_distance_matrix(tree_ann, max_dist, output_type="coo_matrix")
+
+    if coo.nnz == 0:  # Nothing is close enough.
+        return np.empty(0, int), np.empty(0, int), np.arange(N), np.arange(M)
+
+    cost = np.full((N, M), 5 * max_dist, dtype=float)  # Pad missing edges with a finite cost > max_dist.
+    cost[coo.row, coo.col] = coo.data  # Fill only the existing edges.
+
+    # 2. Optimal one-to-one assignment (Hungarian algorithm) -----------------------
+    row_ind, col_ind = linear_sum_assignment(cost)
+
+    # Keep only assignments within max_dist; the rest landed on the padded
+    # placeholder edges and do not correspond to real matches.
+    valid_mask = cost[row_ind, col_ind] <= max_dist
+    tp_det_ids = row_ind[valid_mask]
+    tp_ann_ids = col_ind[valid_mask]
+    assert len(tp_det_ids) == len(tp_ann_ids)
+
+    # 3. Derive false positives and false negatives --------------------------------
+    fp_det_ids = np.setdiff1d(np.arange(N), tp_det_ids, assume_unique=True)
+    fn_ann_ids = np.setdiff1d(np.arange(M), tp_ann_ids, assume_unique=True)
+
+    return tp_det_ids, tp_ann_ids, fp_det_ids, fn_ann_ids
+
+
 def for_visualization(segmentation, annotations, matches):
     green_red = ["#00FF00", "#FF0000"]

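
`match_detections` relies on `scipy.spatial.cKDTree` and `scipy.optimize.linear_sum_assignment`, which are assumed to be imported at the top of the module. A short sketch of turning its output into detection metrics, using randomly generated placeholder points:

    import numpy as np
    # Imports needed by match_detections itself:
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial import cKDTree

    # Random placeholder points; in practice these come from the detection
    # pipeline and from the annotation dataframe.
    rng = np.random.default_rng(0)
    detections = rng.uniform(0, 100, size=(50, 3))
    annotations = rng.uniform(0, 100, size=(40, 3))

    tp_det, tp_ann, fp, fn = match_detections(detections, annotations, max_dist=5.0)

    # Every detection is either a TP or FP, every annotation either a TP or FN.
    precision = len(tp_det) / (len(tp_det) + len(fp))
    recall = len(tp_ann) / (len(tp_ann) + len(fn))
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")

Because the Hungarian assignment is one-to-one, `tp_det` and `tp_ann` always have equal length, so precision and recall use the same true-positive count.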