11"""This module implements nucleus detection engine."""
+
from __future__ import annotations

import os
- from pathlib import Path
import sys
- import numpy as np
- import pandas as pd
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Unpack
+
import dask.array as da
import dask.dataframe as dd
- from dask.delayed import delayed
- import dask
+ import numpy as np
+ import pandas as pd
+ from shapely.geometry import Point
from skimage.feature import peak_local_max
from skimage.measure import label, regionprops

- from tiatoolbox.models.engine.engine_abc import EngineABCRunParams
+ from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.models.engine.semantic_segmentor import (
    SemanticSegmentor,
-     SemanticSegmentorRunParams
+     SemanticSegmentorRunParams,
)
- from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.models.models_abc import ModelABC
- from shapely.geometry import Point
- from typing import TYPE_CHECKING, Unpack

if TYPE_CHECKING:  # pragma: no cover
    import os

-     from torch.utils.data import DataLoader
-
    from tiatoolbox.annotation import AnnotationStore
    from tiatoolbox.models.engine.io_config import IOSegmentorConfig
    from tiatoolbox.models.models_abc import ModelABC
-     from tiatoolbox.type_hints import Resolution
    from tiatoolbox.wsicore import WSIReader

+
def dataframe_to_annotation_store(
    df: pd.DataFrame,
) -> AnnotationStore:
-     """
-     Convert a pandas DataFrame with columns ['x','y','type','prob']
+     """Convert a pandas DataFrame with columns ['x','y','type','prob']
    to an AnnotationStore.
    """
-     from tiatoolbox.annotation import SQLiteStore, Annotation
+     from tiatoolbox.annotation import Annotation, SQLiteStore

    ann_store = SQLiteStore()
    for _, row in df.iterrows():
        x = int(row["x"])
        y = int(row["y"])
        obj_type = int(row["type"])
        prob = float(row["prob"])
-         ann = Annotation(geometry=Point(x, y), properties={"type": "nuclei", "probability": prob})
+         ann = Annotation(
+             geometry=Point(x, y), properties={"type": "nuclei", "probability": prob}
+         )
        ann_store.append(ann)
    return ann_store


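A minimal usage sketch for dataframe_to_annotation_store (illustrative only, not part of this diff; the output filename is hypothetical):

    import pandas as pd

    detections = pd.DataFrame(
        {"x": [10, 42], "y": [15, 80], "type": [0, 1], "prob": [0.91, 0.77]}
    )
    store = dataframe_to_annotation_store(detections)
    store.dump("nuclei_annotations.db")  # hypothetical path; dump() writes the store to disk
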
- def processed_mask_fn(img2d: np.ndarray, min_distance: int, threshold_abs: float | int) -> np.ndarray:
-     """
-     Build a boolean mask (H, W) of objects from a 2D probability map.
+ def processed_mask_fn(
+     img2d: np.ndarray, min_distance: int, threshold_abs: float
+ ) -> np.ndarray:
+     """Build a boolean mask (H, W) of objects from a 2D probability map.
    Here: 1-pixel objects from peak_local_max. Add morphology inside if you need blobs.
    """
    H, W = img2d.shape
    mask = np.zeros((H, W), dtype=bool)
-     coords = peak_local_max(img2d, min_distance=min_distance, threshold_abs=threshold_abs)
+     coords = peak_local_max(
+         img2d, min_distance=min_distance, threshold_abs=threshold_abs
+     )
    if coords.size:
        r, c = coords[:, 0], coords[:, 1]
        mask[r, c] = True
    return mask

+
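A quick sanity check for processed_mask_fn on a synthetic probability map (illustrative only; the array values are made up):

    import numpy as np

    toy = np.zeros((64, 64), dtype=np.float32)
    toy[10, 12] = 250.0
    toy[40, 50] = 230.0
    toy_mask = processed_mask_fn(toy, min_distance=3, threshold_abs=205)
    print(int(toy_mask.sum()))  # 2: each detected peak becomes a single True pixel
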
def block_regionprops_mapoverlap(
    block: np.ndarray,
    block_info,
    min_distance: int,
-     threshold_abs: float | int,
+     threshold_abs: float,
    depth_h: int,
    depth_w: int,
) -> np.ndarray:
-     """
-     Runs inside da.map_overlap on a padded NumPy block: (h_pad, w_pad, C).
+     """Runs inside da.map_overlap on a padded NumPy block: (h_pad, w_pad, C).
    Builds a processed mask per channel, runs label+regionprops, and writes
    region score (mean_intensity) at centroid pixels. Keeps only centroids
    whose (row,col) lie in the interior window:
@@ -87,11 +88,10 @@ def block_regionprops_mapoverlap(

    # --- derive core (pre-overlap) size for THIS block safely ---
    info = block_info[0]
-     locs = info["array-location"]  # [(r0,r1),(c0,c1),(ch0,ch1)]
-     core_h = int(locs[0][1] - locs[0][0])  # r1 - r0
+     locs = info["array-location"]  # [(r0,r1),(c0,c1),(ch0,ch1)]
+     core_h = int(locs[0][1] - locs[0][0])  # r1 - r0
    core_w = int(locs[1][1] - locs[1][0])

-
    rmin, rmax = depth_h, depth_h + core_h
    cmin, cmax = depth_w, depth_w + core_w

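As a concrete illustration of the interior window above (numbers are made up): with depth_h = depth_w = 5 and a 256 x 256 core block, the padded block passed to this function is 266 x 266, and only centroids whose padded-block row/col fall in [5, 261) are kept, so a detection near a chunk edge is not reported twice by neighbouring blocks.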
@@ -118,9 +118,9 @@ def block_regionprops_mapoverlap(


def detect_with_map_overlap(probs, min_distance, threshold_abs, depth_pixels):
-     """
-     probs: Dask array (H, W, C), float.
+     """probs: Dask array (H, W, C), float.
    depth_pixels: halo in pixels for H/W (use >= min_distance and >= any morphology radius).
+
    Returns:
        scores: da.Array (H, W, C) with mean_intensity at centroids, 0 elsewhere.
    """
@@ -139,9 +139,9 @@ def detect_with_map_overlap(probs, min_distance, threshold_abs, depth_pixels):
    )
    return scores

+
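A small driver sketch for detect_with_map_overlap (illustrative only; the random input and chunking are placeholders, while min_distance, threshold_abs and depth_pixels mirror the values used in post_process_wsi below):

    import dask.array as da
    import numpy as np

    rng = np.random.default_rng(0)
    probs = da.from_array(
        rng.uniform(0, 255, size=(512, 512, 2)).astype(np.float32),
        chunks=(256, 256, 2),
    )
    scores = detect_with_map_overlap(probs, min_distance=3, threshold_abs=205, depth_pixels=5)
    print(scores.shape)  # (512, 512, 2): per the docstring, non-zero only at centroids
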
def scores_to_ddf(scores: da.Array, x_offset: int, y_offset: int) -> dd.DataFrame:
-     """
-     Convert (H, W, C) scores -> Dask DataFrame with columns: x, y, type, prob.
+     """Convert (H, W, C) scores -> Dask DataFrame with columns: x, y, type, prob.
    Uses da.extract(mask, scores) to avoid vindex on Dask indexers.
    """
    # 1) Build a boolean mask of detections
@@ -156,9 +156,9 @@ def scores_to_ddf(scores: da.Array, x_offset: int, y_offset: int) -> dd.DataFram
    # 4) Assemble a Dask DataFrame
    ddf = dd.concat(
        [
-             dd.from_dask_array(xx.astype("int64"), columns="x"),
-             dd.from_dask_array(yy.astype("int64"), columns="y"),
-             dd.from_dask_array(cc.astype("int64"), columns="type"),
+             dd.from_dask_array(xx.astype("int64"), columns="x"),
+             dd.from_dask_array(yy.astype("int64"), columns="y"),
+             dd.from_dask_array(cc.astype("int64"), columns="type"),
            dd.from_dask_array(ss.astype("float32"), columns="prob"),
        ],
        axis=1,
@@ -172,8 +172,7 @@ def scores_to_ddf(scores: da.Array, x_offset: int, y_offset: int) -> dd.DataFram


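A matching sketch for scores_to_ddf (illustrative only; zero offsets are passed so x/y should stay in array coordinates, and the expected output assumes non-zero scores mark detections):

    import dask.array as da
    import numpy as np

    toy_scores = np.zeros((8, 8, 1), dtype=np.float32)
    toy_scores[2, 3, 0] = 0.9
    toy_scores[5, 6, 0] = 0.8
    ddf = scores_to_ddf(da.from_array(toy_scores, chunks=(4, 4, 1)), x_offset=0, y_offset=0)
    print(ddf.compute())  # expected: two rows with columns x, y, type, prob
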
def greedy_radius_nms_pandas_all(df: pd.DataFrame, radius: int) -> pd.DataFrame:
-     """
-     Greedy NMS across ALL detections (no per-type grouping).
+     """Greedy NMS across ALL detections (no per-type grouping).
    Keeps the highest-prob point and suppresses any other point within 'radius' pixels.

    Expects columns: ['x','y','type','prob'].
@@ -245,7 +244,7 @@ class NucleusDetector(SemanticSegmentor):
        device (str):
            Device to run the model on, e.g., 'cpu' or 'cuda:0'.
        verbose (bool):
-             Whether to output logging information.
+             Whether to output logging information.


        Examples:
@@ -307,11 +306,12 @@ def __init__(
            verbose=verbose,
        )

-     def post_process_patches(self,
+     def post_process_patches(
+         self,
        raw_predictions: da.Array,
-         prediction_shape: tuple[int, ...],  # noqa: ARG002
-         prediction_dtype: type,  # noqa: ARG002
-         **kwargs: Unpack[SemanticSegmentorRunParams],  # noqa: ARG002
+         prediction_shape: tuple[int, ...],
+         prediction_dtype: type,
+         **kwargs: Unpack[SemanticSegmentorRunParams],
    ) -> da.Array:
        """Define how to post-process patch predictions.

@@ -320,9 +320,8 @@ def post_process_patches(self,

        """

-         pass
-
-     def post_process_wsi(self: NucleusDetector,
+     def post_process_wsi(
+         self: NucleusDetector,
        raw_predictions: da.Array,
        prediction_shape: tuple[int, ...],
        prediction_dtype: type,
@@ -340,13 +339,11 @@ def post_process_wsi(self: NucleusDetector,

        print("Chunk size:", raw_predictions.chunks)

-
-
        scores = detect_with_map_overlap(
            probs=raw_predictions,
            min_distance=3,
-             threshold_abs=205,  # set your threshold
-             depth_pixels=5
+             threshold_abs=205,  # set your threshold
+             depth_pixels=5,
        )
        print("Scores shape:", scores.shape)

@@ -362,10 +359,7 @@ def post_process_wsi(self: NucleusDetector,
        ann_store = dataframe_to_annotation_store(nms_df)
        ann_store.dump(save_path)

-
        sys.exit()
-
-

    def run(
        self: NucleusDetector,
@@ -449,5 +443,3 @@ def run(
            output_type=output_type,
            **kwargs,
        )
-
-