
from __future__ import annotations

-import os
-import sys
from pathlib import Path
from typing import TYPE_CHECKING, Unpack

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
-from shapely.geometry import Point
from skimage.feature import peak_local_max
from skimage.measure import label, regionprops

-from tiatoolbox.models.engine.io_config import IOSegmentorConfig
+from tiatoolbox import logger
+from tiatoolbox.annotation import AnnotationStore
from tiatoolbox.models.engine.semantic_segmentor import (
    SemanticSegmentor,
    SemanticSegmentorRunParams,
)
from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.annotation import Annotation, SQLiteStore, AnnotationStore
from tiatoolbox.utils.misc import df_to_store_nucleus_detector
-from tiatoolbox import logger

if TYPE_CHECKING:  # pragma: no cover
-    import os
-    from tiatoolbox.models.engine.io_config import IOSegmentorConfig
    from tiatoolbox.models.models_abc import ModelABC
-    from tiatoolbox.wsicore import WSIReader


def probability_to_peak_map(
-    img2d: np.ndarray, min_distance: int, threshold_abs: float, threshold_rel: float = 0.0
+    img2d: np.ndarray,
+    min_distance: int,
+    threshold_abs: float,
+    threshold_rel: float = 0.0,
) -> np.ndarray:
    """Build a boolean mask (H, W) of objects from a 2D probability map using peak_local_max.
-
+
    Args:
        img2d (np.ndarray): 2D probability map.
        min_distance (int): Minimum distance between peaks.
        threshold_abs (float): Absolute threshold for peak detection.
        threshold_rel (float, optional): Relative threshold for peak detection. Defaults to 0.0.
+
    Returns:
        mask (np.ndarray): Boolean mask (H, W) with True at peak locations.
    """
    H, W = img2d.shape
    mask = np.zeros((H, W), dtype=bool)
    coords = peak_local_max(
-        img2d, min_distance=min_distance, threshold_abs=threshold_abs, threshold_rel=threshold_rel
+        img2d,
+        min_distance=min_distance,
+        threshold_abs=threshold_abs,
+        threshold_rel=threshold_rel,
    )
    if coords.size:
        r, c = coords[:, 0], coords[:, 1]
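For reference, the call at the heart of probability_to_peak_map can be exercised on a tiny synthetic map. This is a standalone sketch, not part of the diff; the values and thresholds are made up for illustration, and the expected coordinates assume skimage's default border exclusion:

import numpy as np
from skimage.feature import peak_local_max

img2d = np.zeros((9, 9), dtype=np.float32)
img2d[2, 2] = 0.9   # strong peak
img2d[6, 5] = 0.7   # second peak, far enough away to survive
img2d[2, 3] = 0.6   # close neighbour of (2, 2), suppressed by min_distance

coords = peak_local_max(img2d, min_distance=2, threshold_abs=0.5, threshold_rel=0.0)
mask = np.zeros(img2d.shape, dtype=bool)
if coords.size:
    mask[coords[:, 0], coords[:, 1]] = True  # (row, col) peak locations
print(coords.tolist())  # expected: [[2, 2], [6, 5]] (highest peak first)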
@@ -67,7 +67,7 @@ def peak_detection_mapoverlap(
) -> np.ndarray:
    """Runs inside Dask.da.map_overlap on a padded NumPy block: (h_pad, w_pad, C).
    Builds a processed mask per channel, runs peak_local_max then
-    label+regionprops, and writes probability (mean_intensity) at centroid pixels.
+    label+regionprops, and writes probability (mean_intensity) at centroid pixels.
    Keeps only centroids whose (row,col) lie in the interior window:
    rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
    Returns same spatial shape as input block: (h_pad, w_pad, C), float32.
@@ -81,6 +81,7 @@ def peak_detection_mapoverlap(
        depth_w: Halo size in pixels for width (cols).
        calculate_probabilities: If True, write mean_intensity at centroids;
            else write 1.0 at centroids.
+
    Returns:
        out: NumPy array (H, W, C) with probabilities at centroids, 0 elsewhere.
    """
@@ -120,7 +121,9 @@ def peak_detection_mapoverlap(
    return out


-def detection_with_map_overlap(probs: da.Array, min_distance: int, threshold_abs: float, depth_pixels: int) -> da.Array:
+def detection_with_map_overlap(
+    probs: da.Array, min_distance: int, threshold_abs: float, depth_pixels: int
+) -> da.Array:
    """probs: Dask array (H, W, C), float.
    depth_pixels: halo in pixels for H/W (use >= min_distance and >= any morphology radius).

@@ -143,18 +146,21 @@ def detection_with_map_overlap(probs: da.Array, min_distance: int, threshold_abs
    return scores


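detection_with_map_overlap wires the per-block peak detector into dask.array.map_overlap so that peaks near chunk edges see a halo of neighbouring pixels before the halo is trimmed again. A minimal, self-contained sketch of that pattern, with toy data and a pass-through block function rather than the PR's actual block function:

import dask.array as da
import numpy as np

# Hypothetical (H, W, C) probability map; chunk over H/W only, keep channels whole.
probs = da.random.random((64, 64, 3), chunks=(32, 32, 3)).astype(np.float32)

def passthrough(block: np.ndarray) -> np.ndarray:
    # A real block function would detect peaks here and keep only centroids
    # that fall inside the interior (non-halo) window of the block.
    return block

depth_pixels = 8  # choose >= min_distance so peaks near chunk edges are not missed

scores = da.map_overlap(
    passthrough,
    probs,
    depth={0: depth_pixels, 1: depth_pixels, 2: 0},  # halo on rows/cols, none across channels
    boundary=0.0,
    dtype=np.float32,
)
assert scores.shape == probs.shape  # the halo is trimmed off again (trim=True by default)
scores.compute()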
-def centroids_map_to_dask_dataframe(scores: da.Array, x_offset: int = 0, y_offset: int = 0) -> dd.DataFrame:
+def centroids_map_to_dask_dataframe(
+    scores: da.Array, x_offset: int = 0, y_offset: int = 0
+) -> dd.DataFrame:
    """Convert centroid map (H, W, C) into a Dask DataFrame with columns: x, y, type, prob.

    Args:
        scores: Dask array (H, W, C) with probabilities at centroids, 0 elsewhere.
        x_offset: global x offset to add to all x coordinates.
        y_offset: global y offset to add to all y coordinates.
+
    Returns:
        ddf: Dask DataFrame with columns: x, y, type, prob.
    """
    # 1) Build a boolean mask of detections
-
+
    mask = scores > 0
    # 2) Get coordinates and class of detections (lazy 1D Dask arrays)

@@ -172,7 +178,7 @@ def centroids_map_to_dask_dataframe(scores: da.Array, x_offset: int = 0, y_offse
            dd.from_dask_array(ss.astype("float32"), columns="prob"),
        ],
        axis=1,
-        ignore_unknown_divisions=True
+        ignore_unknown_divisions=True,
    )

    # 5) Apply global offsets (if needed)
@@ -184,7 +190,9 @@ def centroids_map_to_dask_dataframe(scores: da.Array, x_offset: int = 0, y_offse
    return ddf


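To make the output layout concrete, here is an eager NumPy/pandas version of what centroids_map_to_dask_dataframe is intended to produce for a toy (H, W, C) centroid map. It is a sketch only; the x = column / y = row mapping follows the usual image convention and is assumed here, since the relevant lines are elided from this hunk:

import numpy as np
import pandas as pd

scores = np.zeros((4, 5, 2), dtype=np.float32)  # toy (H, W, C) centroid map
scores[1, 3, 0] = 0.90                          # class-0 centroid at row 1, col 3
scores[2, 0, 1] = 0.75                          # class-1 centroid at row 2, col 0

rows, cols, channels = np.nonzero(scores > 0)
x_offset, y_offset = 0, 0
df = pd.DataFrame(
    {
        "x": cols + x_offset,    # assumed: x = column index plus global offset
        "y": rows + y_offset,    # assumed: y = row index plus global offset
        "type": channels,        # channel index used as the detection class
        "prob": scores[rows, cols, channels].astype("float32"),
    }
)
print(df)  # two rows: (x=3, y=1, type=0, prob~0.9) and (x=0, y=2, type=1, prob~0.75)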
-def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float = 0.5) -> pd.DataFrame:
+def nucleus_detection_nms(
+    df: pd.DataFrame, radius: int, overlap_threshold: float = 0.5
+) -> pd.DataFrame:
    """Greedy NMS across ALL detections.

    Keeps the highest-prob detection, removes any other point within 'radius' pixels > overlap_threshold.
@@ -215,7 +223,7 @@ def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float
    coords = sub[["x", "y"]].to_numpy(dtype=np.float64)
    r = float(radius)
    two_r = 2.0 * r
-    two_r2 = (two_r * two_r)  # distance^2 cutoff for any overlap
+    two_r2 = two_r * two_r  # distance^2 cutoff for any overlap

    suppressed = np.zeros(len(sub), dtype=bool)
    keep_idx = []
@@ -232,18 +240,19 @@ def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float
        d2 = dx * dx + dy * dy

        # Only points with d < 2r can have nonzero overlap
-        cand = (d2 <= two_r2)
+        cand = d2 <= two_r2
        cand[i] = False  # don't suppress the kept point itself
        if not np.any(cand):
            continue

        d = np.sqrt(d2[cand])

-
        # Safe cosine argument = (distance ÷ diameter), Clamp for numerical stability
        u = np.clip(d / (2.0 * r), -1.0, 1.0)
        # Exact intersection area of two equal-radius circles.
-        inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt(np.clip(4.0 * r * r - d * d, 0.0, None))
+        inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt(
+            np.clip(4.0 * r * r - d * d, 0.0, None)
+        )

        union = 2.0 * np.pi * (r * r) - inter
        iou = inter / union
@@ -252,7 +261,7 @@ def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float
        idx_cand = np.where(cand)[0]
        to_suppress = idx_cand[iou >= overlap_threshold]
        suppressed[to_suppress] = True
-
+
    kept = sub.iloc[keep_idx].copy()
    return kept

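The arccos/sqrt expression above is the exact intersection area of two equal-radius circles. A standalone numeric check of the same arithmetic (the radius value is arbitrary) shows why two detections one radius apart (IoU of about 0.24) are kept under the default overlap_threshold of 0.5, while near-coincident ones are suppressed:

import numpy as np

def circle_iou(d, r):
    # Exact IoU of two circles of radius r whose centres are d apart.
    u = np.clip(d / (2.0 * r), -1.0, 1.0)
    inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt(
        np.clip(4.0 * r * r - d * d, 0.0, None)
    )
    union = 2.0 * np.pi * (r * r) - inter
    return inter / union

r = 6.0  # hypothetical nucleus radius in pixels
for d in (0.0, r, 2.0 * r):
    print(f"d={d:>4.1f}px  IoU={circle_iou(d, r):.3f}")
# d= 0.0px  IoU=1.000
# d= 6.0px  IoU=0.243
# d=12.0px  IoU=0.000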
@@ -363,6 +372,7 @@ def post_process_patches(
            raw_predictions (da.Array): The raw predictions from the model.
            prediction_shape (tuple[int, ...]): The shape of the predictions.
            prediction_dtype (type): The data type of the predictions.
+
        Returns:
            A list of DataFrames containing the post-processed predictions for each patch.

@@ -376,7 +386,6 @@ def post_process_patches(
            batch_predictions.append(self.model.postproc_func(raw_predictions[i]))
        return batch_predictions

-
    def post_process_wsi(
        self: NucleusDetector,
        raw_predictions: da.Array,
@@ -396,8 +405,9 @@ def post_process_wsi(
        logger.info(f"Raw probabilities dtype: {prediction_dtype}")
        logger.info(f"Chunk size: {raw_predictions.chunks}")

-        detection_df = self.model.postproc(raw_predictions, prediction_shape, prediction_dtype)
-
+        detection_df = self.model.postproc(
+            raw_predictions, prediction_shape, prediction_dtype
+        )

        return detection_df

@@ -441,11 +451,9 @@ def save_predictions(

                save_paths.append(out_file)
            return save_paths
-        else:
-            return df_to_store_nucleus_detector(
-                processed_predictions['predictions'],
-                scale_factor=scale_factor,
-                save_path=save_path,
-                class_dict=class_dict,
-            )
-
+        return df_to_store_nucleus_detector(
+            processed_predictions["predictions"],
+            scale_factor=scale_factor,
+            save_path=save_path,
+            class_dict=class_dict,
+        )