11"""Nucleus Detection Engine for Digital Pathology (WSIs and patches).
22
3- This module implements the `NucleusDetector` class— which extends
4- `SemanticSegmentor`— to perform instance-level nucleus detection on
3+ This module implements the `NucleusDetector` class which extends
4+ `SemanticSegmentor` to perform instance-level nucleus detection on
55histology images. It supports patch-mode and whole slide image (WSI)
66workflows using TIAToolbox or custom PyTorch models, and provides
7- utilities for halo-aware post-processing (centroid extraction,
8- thresholding), merging detections across patch seams , and exporting
9- results in multiple formats (in-memory dict, Zarr, AnnotationStore).
7+ utilities for parallel post-processing (centroid extraction, thresholding) ,
8+ merging detections across patch, and exporting results in multiple
9+ formats (in-memory dict, Zarr, AnnotationStore).
1010
1111Classes
1212-------
1313NucleusDetectorRunParams
1414 TypedDict specifying runtime configuration keys for detection.
1515NucleusDetector
16- Core engine for nucleus detection; orchestrates preprocessing,
17- inference, post-processing, and output saving.
16+ Core engine for nucleus detection on image patches or WSIs.
1817
1918Examples:
2019--------
2120>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
22- >>> detector = NucleusDetector(model="mapde-conic", batch_size=16, num_workers=8 )
21+ >>> detector = NucleusDetector(model="mapde-conic")
2322>>> # WSI workflow: save to AnnotationStore (.db)
2423>>> out = detector.run(
2524... images=[pathlib.Path("example_wsi.tiff")],
3837
3938Notes:
4039-----
41- - Consistent with TIAToolbox engines, outputs can be returned as Python
42- dictionaries, saved as Zarr groups, or converted to AnnotationStore (.db).
43- - Post-processing uses tile rechunking and halo padding to ensure robust
44- centroid extraction across chunk boundaries.
45- - Coordinate scaling for AnnotationStore is derived from the dataloader's
46- resolution metadata (see base engine); it can be overridden via
47- `scale_factor` in `NucleusDetectorRunParams`.
40+ - Outputs can be returned as Python dictionaries, saved as Zarr groups,
41+ or converted to AnnotationStore (.db).
42+ - Post-processing uses tile rechunking and halo padding to facilitate
43+ centroid extraction near chunk boundaries.
4844
4945"""
5046
@@ -116,7 +112,7 @@ class NucleusDetectorRunParams(SemanticSegmentorRunParams, total=False):
116112 Relative detection threshold (e.g., with respect to local maxima).
117113 postproc_tile_shape (tuple[int, int]):
118114 Tile shape (height, width) used during post-processing
119- (in pixels) to control rechunking and overlap behavior.
115+ (in pixels) to control rechunking behavior.
120116 return_labels (bool):
121117 Whether to return labels with predictions.
122118 return_probabilities (bool):
@@ -250,8 +246,6 @@ def __init__(
250246 pretrained TIAToolbox architecture or as a custom ``torch.nn.Module``.
251247 When ``model`` is a string, the corresponding pretrained weights are
252248 automatically downloaded unless explicitly overridden via ``weights``.
253- The engine configures batch size, workers, IO settings, and device
254- placement consistently with the base classes.
255249
256250 Args:
257251 model (str or ModelABC):
@@ -300,63 +294,31 @@ def post_process_patches(
300294 ) -> dict :
301295 """Post-process patch-level detection outputs.
302296
303- Applies the model's post-processing routine (e.g., centroid extraction and
297+ Applies the model's post-processing function (e.g., centroid extraction and
304298 thresholding) to each patch's probability map, yielding per-patch detection
305- arrays suitable for saving or further merging. The behavior and parameters
306- mirror the conventions used in TIAToolbox engines.
299+ arrays suitable for saving or further merging.
307300
308301 Args:
309302 raw_predictions (da.Array):
310303 Patch predictions of shape ``(B, H, W, C)``, where ``B`` is the number
311304 of patches (probabilities/logits).
312305 prediction_shape (tuple[int, ...]):
313- Expected prediction shape for validation/logging .
306+ Expected prediction shape.
314307 prediction_dtype (type):
315- Expected prediction dtype for validation/logging .
308+ Expected prediction dtype.
316309 **kwargs (NucleusDetectorRunParams):
317310 Additional runtime parameters to configure segmentation.
318311
319312 Optional Keys:
320- auto_get_mask (bool):
321- Whether to automatically generate segmentation masks using
322- `wsireader.tissue_mask()` during processing.
323- batch_size (int):
324- Number of image patches to feed to the model in a forward pass.
325- class_dict (dict):
326- Optional dictionary mapping classification outputs to
327- class names.
328- device (str):
329- Device to run the model on (e.g., "cpu", "cuda").
330- labels (list):
331- Optional labels for input images. Only a single label per image
332- is supported.
333- memory_threshold (int):
334- Memory usage threshold (in percentage) to
335- trigger caching behavior.
336- num_workers (int):
337- Number of workers used in DataLoader.
338- output_file (str):
339- Output file name for saving results (e.g., .zarr or .db).
340- output_resolutions (Resolution):
341- Resolution used for writing output predictions.
342- patch_output_shape (tuple[int, int]):
343- Shape of output patches (height, width).
344313 min_distance (int):
345- Minimum distance separating two nuclei (in pixels).
346- postproc_tile_shape (tuple[int, int]):
347- Tile shape (height, width) for post-processing (in pixels).
348- return_labels (bool):
349- Whether to return labels with predictions.
350- return_probabilities (bool):
351- Whether to return per-class probabilities.
352- scale_factor (tuple[float, float]):
353- Scale factor for converting annotations to baseline resolution.
354- Typically model_mpp / slide_mpp.
355- stride_shape (tuple[int, int]):
356- Stride used during WSI processing.
357- Defaults to patch_input_shape.
358- verbose (bool):
359- Whether to output logging information.
314+ Minimum separation between nuclei (in pixels) used during
315+ centroid extraction/post-processing.
316+ threshold_abs (float):
317+ Absolute detection threshold applied to model outputs.
318+ threshold_rel (float):
319+ Relative detection threshold
320+ (e.g., with respect to local maxima).
321+
360322
361323 Returns:
362324 dict[str, list[da.Array]]:
@@ -372,8 +334,6 @@ class names.
372334
373335 Notes:
374336 - If thresholds are not provided via ``kwargs``, model defaults are used.
375- - The output structure intentionally mirrors other TIAToolbox engines,
376- enabling downstream saving as ``dict``, ``zarr``, or ``annotationstore``.
377337
378338 """
379339 logger .info ("Post processing patch predictions in NucleusDetector" )
@@ -423,63 +383,29 @@ def post_process_wsi(
423383 Processes the full-slide prediction map using Dask's block-wise operations
424384 to extract nuclei centroids across the entire WSI. The prediction map is
425385 first re-chunked to the model's preferred post-processing tile shape, and
426- `dask.map_overlap` is used to ensure accurate centroid extraction along
427- chunk boundaries via halo padding . The resulting centroid maps are then
428- converted into final detection arrays (x, y, classes, probabilities).
386+ `dask.map_overlap` with halo padding is used to facilitate centroid
387+ extraction on large prediction maps . The resulting centroid maps are then
388+ converted into detection arrays (x, y, classes, probabilities).
429389
430390 Args:
431391 raw_predictions (da.Array):
432392 WSI prediction map of shape ``(H, W, C)`` containing
433393 per-class probabilities or logits.
434394 prediction_shape (tuple[int, ...]):
435- Expected prediction shape (provided for consistency with the
436- base engine interface).
395+ Expected prediction shape.
437396 prediction_dtype (type):
438- Expected prediction dtype (also provided for consistency) .
397+ Expected prediction dtype.
439398 **kwargs (NucleusDetectorRunParams):
440399 Additional runtime parameters to configure segmentation.
441400
442401 Optional Keys:
443- auto_get_mask (bool):
444- Whether to automatically generate segmentation masks using
445- `wsireader.tissue_mask()` during processing.
446- batch_size (int):
447- Number of image patches to feed to the model in a forward pass.
448- class_dict (dict):
449- Optional dictionary mapping classification outputs to
450- class names.
451- device (str):
452- Device to run the model on (e.g., "cpu", "cuda").
453- labels (list):
454- Optional labels for input images. Only a single label per image
455- is supported.
456- memory_threshold (int):
457- Memory usage threshold (in percentage) to
458- trigger caching behavior.
459- num_workers (int):
460- Number of workers used in DataLoader.
461- output_file (str):
462- Output file name for saving results (e.g., .zarr or .db).
463- output_resolutions (Resolution):
464- Resolution used for writing output predictions.
465- patch_output_shape (tuple[int, int]):
466- Shape of output patches (height, width).
467402 min_distance (int):
468403 Minimum distance separating two nuclei (in pixels).
469- postproc_tile_shape (tuple[int, int]):
470- Tile shape (height, width) for post-processing (in pixels).
471- return_labels (bool):
472- Whether to return labels with predictions.
473- return_probabilities (bool):
474- Whether to return per-class probabilities.
475- scale_factor (tuple[float, float]):
476- Scale factor for converting annotations to baseline resolution.
477- Typically model_mpp / slide_mpp.
478- stride_shape (tuple[int, int]):
479- Stride used during WSI processing.
480- Defaults to patch_input_shape.
481- verbose (bool):
482- Whether to output logging information.
404+ threshold_abs (float):
405+ Absolute detection threshold applied to model outputs.
406+ threshold_rel (float):
407+ Relative detection threshold
408+ (e.g., with respect to local maxima).
483409
484410 Returns:
485411 dict[str, da.Array]:
@@ -493,8 +419,6 @@ class names.
493419 - Halo padding ensures that nuclei crossing tile/chunk boundaries
494420 are not fragmented or duplicated.
495421 - If thresholds are not explicitly provided, model defaults are used.
496- - The output structure matches TIAToolbox conventions so it can be
497- saved directly as a ``dict``, ``zarr`` group, or ``annotationstore``.
498422
499423 """
500424 _ = prediction_shape
@@ -643,8 +567,6 @@ class names.
643567 - For non-AnnotationStore outputs, this method delegates to the
644568 base engine's saving function to preserve consistency across
645569 TIAToolbox engines.
646- - Coordinate scaling and class name mapping are applied only when saving
647- to AnnotationStore and when provided via ``kwargs``.
648570
649571 """
650572 if output_type .lower () != "annotationstore" :
@@ -722,8 +644,6 @@ def _save_predictions_annotation_store(
722644 - This method centralizes the translation of detection arrays into
723645 `Annotation` objects and abstracts batching logic via
724646 ``_write_detection_arrays_to_store``.
725- - When writing to disk, the resulting file always uses a ``.db`` suffix,
726- consistent with other TIAToolbox engines.
727647
728648 """
729649 logger .info ("Saving predictions as AnnotationStore." )
@@ -771,8 +691,7 @@ def _centroid_maps_to_detection_arrays(
771691 This helper function extracts non-zero centroid predictions from a
772692 Dask array of centroid maps and flattens them into coordinate,
773693 class, and probability arrays suitable for saving or further
774- processing. The output format mirrors the detection structure used
775- throughout TIAToolbox engines.
694+ processing.
776695
777696 Args:
778697 detection_maps (da.Array):
@@ -822,7 +741,7 @@ def _write_detection_arrays_to_store(
822741 """Write detection arrays to an AnnotationStore in batches.
823742
824743 Converts coordinate, class, and probability arrays into `Annotation`
825- records and appends them to an SQLite-backed store in configurable
744+ objects and appends them to an SQLite-backed store in configurable
826745 batch sizes. Coordinates are scaled to baseline slide resolution using
827746 the provided `scale_factor`, and optional class-ID remapping is applied
828747 via `class_dict`.
@@ -854,7 +773,7 @@ def _write_detection_arrays_to_store(
854773 - Class mapping is applied per-record; unmapped IDs fall back to their
855774 original values.
856775 - Writing in batches reduces memory pressure and improves throughput
857- on large detections.
776+ on large number of detections.
858777
859778 """
860779 xs , ys , classes , probs = detection_arrays
@@ -949,13 +868,9 @@ class IDs.
949868 all detections.
950869
951870 Notes:
952- - All detection arrays are converted to NumPy using ``np.asarray`` to
953- ensure consistent batching.
954871 - The heavy lifting is delegated to
955872 :meth:`NucleusDetector._write_detection_arrays_to_store`,
956873 which performs coordinate scaling, class mapping, and batch writing.
957- - The database always uses the ``.db`` suffix for consistency with
958- TIAToolbox's AnnotationStore format TIAToolbox's AnnotationStore format.
959874
960875 """
961876 xs = detection_arrays ["x" ]
@@ -1067,6 +982,11 @@ class names.
1067982 Shape of output patches (height, width).
1068983 min_distance (int):
1069984 Minimum distance separating two nuclei (in pixels).
985+ threshold_abs (float):
986+ Absolute detection threshold applied to model outputs.
987+ threshold_rel (float):
988+ Relative detection threshold
989+ (e.g., with respect to local maxima).
1070990 postproc_tile_shape (tuple[int, int]):
1071991 Tile shape (height, width) for post-processing (in pixels).
1072992 return_labels (bool):
@@ -1089,26 +1009,24 @@ class names.
10891009 to its output path.
10901010
10911011 Examples:
1092- >>> wsis = ['wsi1.svs', 'wsi2.svs']
1093- >>> image_patches = [np.ndarray, np.ndarray]
1094- >>> nuc_detector = NucleusDetector(model="sccnn-conic")
1095- >>> output = nuc_detector.run(image_patches, patch_mode=True)
1096- >>> output
1097- ... "/path/to/Output.db"
1098-
1099- >>> output = nuc_detector.run(
1100- ... image_patches,
1101- ... patch_mode=True,
1102- ... output_type="zarr"
1012+ >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
1013+ >>> detector = NucleusDetector(model="mapde-conic")
1014+ >>> # WSI workflow: save to AnnotationStore (.db)
1015+ >>> out = detector.run(
1016+ ... images=[pathlib.Path("example_wsi.tiff")],
1017+ ... patch_mode=False,
1018+ ... device="cuda",
1019+ ... save_dir=pathlib.Path("output_directory/"),
1020+ ... overwrite=True,
1021+ ... output_type="annotationstore",
1022+ ... class_dict={0: "nucleus"},
1023+ ... auto_get_mask=True,
1024+ ... memory_threshold=80,
11031025 ... )
1104- >>> output
1105- ... "/path/to/Output.zarr"
1106-
1107- >>> output = nuc_detector.run(wsis, patch_mode=False)
1108- >>> output.keys()
1109- ... ['wsi1.svs', 'wsi2.svs']
1110- >>> output['wsi1.svs']
1111- ... "/path/to/wsi1.db"
1026+ >>> # Patch workflow: return in-memory detections
1027+ >>> patches = [np.ndarray, np.ndarray] # NHWC
1028+ >>> out = detector.run(patches, patch_mode=True, output_type="dict")
1029+
11121030
11131031 """
11141032 return super ().run (
0 commit comments