Commit 66a5638

update docstring
1 parent 780740e commit 66a5638

1 file changed

+59
-141
lines changed

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 59 additions & 141 deletions
@@ -1,25 +1,24 @@
11
"""Nucleus Detection Engine for Digital Pathology (WSIs and patches).
22
3-
This module implements the `NucleusDetector` classwhich extends
4-
`SemanticSegmentor`to perform instance-level nucleus detection on
3+
This module implements the `NucleusDetector` class which extends
4+
`SemanticSegmentor` to perform instance-level nucleus detection on
55
histology images. It supports patch-mode and whole slide image (WSI)
66
workflows using TIAToolbox or custom PyTorch models, and provides
7-
utilities for halo-aware post-processing (centroid extraction,
8-
thresholding), merging detections across patch seams, and exporting
9-
results in multiple formats (in-memory dict, Zarr, AnnotationStore).
7+
utilities for parallel post-processing (centroid extraction, thresholding),
8+
merging detections across patches, and exporting results in multiple
9+
formats (in-memory dict, Zarr, AnnotationStore).
1010
1111
Classes
1212
-------
1313
NucleusDetectorRunParams
1414
TypedDict specifying runtime configuration keys for detection.
1515
NucleusDetector
16-
Core engine for nucleus detection; orchestrates preprocessing,
17-
inference, post-processing, and output saving.
16+
Core engine for nucleus detection on image patches or WSIs.
1817
1918
Examples:
2019
--------
2120
>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
22-
>>> detector = NucleusDetector(model="mapde-conic", batch_size=16, num_workers=8)
21+
>>> detector = NucleusDetector(model="mapde-conic")
2322
>>> # WSI workflow: save to AnnotationStore (.db)
2423
>>> out = detector.run(
2524
... images=[pathlib.Path("example_wsi.tiff")],
@@ -38,13 +37,10 @@
3837
3938
Notes:
4039
-----
41-
- Consistent with TIAToolbox engines, outputs can be returned as Python
42-
dictionaries, saved as Zarr groups, or converted to AnnotationStore (.db).
43-
- Post-processing uses tile rechunking and halo padding to ensure robust
44-
centroid extraction across chunk boundaries.
45-
- Coordinate scaling for AnnotationStore is derived from the dataloader's
46-
resolution metadata (see base engine); it can be overridden via
47-
`scale_factor` in `NucleusDetectorRunParams`.
40+
- Outputs can be returned as Python dictionaries, saved as Zarr groups,
41+
or converted to AnnotationStore (.db).
42+
- Post-processing uses tile rechunking and halo padding to facilitate
43+
centroid extraction near chunk boundaries.
4844
4945
"""
5046

@@ -116,7 +112,7 @@ class NucleusDetectorRunParams(SemanticSegmentorRunParams, total=False):
116112
Relative detection threshold (e.g., with respect to local maxima).
117113
postproc_tile_shape (tuple[int, int]):
118114
Tile shape (height, width) used during post-processing
119-
(in pixels) to control rechunking and overlap behavior.
115+
(in pixels) to control rechunking behavior.
120116
return_labels (bool):
121117
Whether to return labels with predictions.
122118
return_probabilities (bool):
@@ -250,8 +246,6 @@ def __init__(
250246
pretrained TIAToolbox architecture or as a custom ``torch.nn.Module``.
251247
When ``model`` is a string, the corresponding pretrained weights are
252248
automatically downloaded unless explicitly overridden via ``weights``.
253-
The engine configures batch size, workers, IO settings, and device
254-
placement consistently with the base classes.
255249
256250
Args:
257251
model (str or ModelABC):
@@ -300,63 +294,31 @@ def post_process_patches(
300294
) -> dict:
301295
"""Post-process patch-level detection outputs.
302296
303-
Applies the model's post-processing routine (e.g., centroid extraction and
297+
Applies the model's post-processing function (e.g., centroid extraction and
304298
thresholding) to each patch's probability map, yielding per-patch detection
305-
arrays suitable for saving or further merging. The behavior and parameters
306-
mirror the conventions used in TIAToolbox engines.
299+
arrays suitable for saving or further merging.
307300
308301
Args:
309302
raw_predictions (da.Array):
310303
Patch predictions of shape ``(B, H, W, C)``, where ``B`` is the number
311304
of patches (probabilities/logits).
312305
prediction_shape (tuple[int, ...]):
313-
Expected prediction shape for validation/logging.
306+
Expected prediction shape.
314307
prediction_dtype (type):
315-
Expected prediction dtype for validation/logging.
308+
Expected prediction dtype.
316309
**kwargs (NucleusDetectorRunParams):
317310
Additional runtime parameters to configure detection.
318311
319312
Optional Keys:
320-
auto_get_mask (bool):
321-
Whether to automatically generate segmentation masks using
322-
`wsireader.tissue_mask()` during processing.
323-
batch_size (int):
324-
Number of image patches to feed to the model in a forward pass.
325-
class_dict (dict):
326-
Optional dictionary mapping classification outputs to
327-
class names.
328-
device (str):
329-
Device to run the model on (e.g., "cpu", "cuda").
330-
labels (list):
331-
Optional labels for input images. Only a single label per image
332-
is supported.
333-
memory_threshold (int):
334-
Memory usage threshold (in percentage) to
335-
trigger caching behavior.
336-
num_workers (int):
337-
Number of workers used in DataLoader.
338-
output_file (str):
339-
Output file name for saving results (e.g., .zarr or .db).
340-
output_resolutions (Resolution):
341-
Resolution used for writing output predictions.
342-
patch_output_shape (tuple[int, int]):
343-
Shape of output patches (height, width).
344313
min_distance (int):
345-
Minimum distance separating two nuclei (in pixels).
346-
postproc_tile_shape (tuple[int, int]):
347-
Tile shape (height, width) for post-processing (in pixels).
348-
return_labels (bool):
349-
Whether to return labels with predictions.
350-
return_probabilities (bool):
351-
Whether to return per-class probabilities.
352-
scale_factor (tuple[float, float]):
353-
Scale factor for converting annotations to baseline resolution.
354-
Typically model_mpp / slide_mpp.
355-
stride_shape (tuple[int, int]):
356-
Stride used during WSI processing.
357-
Defaults to patch_input_shape.
358-
verbose (bool):
359-
Whether to output logging information.
314+
Minimum separation between nuclei (in pixels) used during
315+
centroid extraction/post-processing.
316+
threshold_abs (float):
317+
Absolute detection threshold applied to model outputs.
318+
threshold_rel (float):
319+
Relative detection threshold
320+
(e.g., with respect to local maxima).
321+
360322
361323
Returns:
362324
dict[str, list[da.Array]]:
@@ -372,8 +334,6 @@ class names.
372334
373335
Notes:
374336
- If thresholds are not provided via ``kwargs``, model defaults are used.
375-
- The output structure intentionally mirrors other TIAToolbox engines,
376-
enabling downstream saving as ``dict``, ``zarr``, or ``annotationstore``.
377337
378338
"""
379339
logger.info("Post processing patch predictions in NucleusDetector")
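
Reviewer note: the trimmed "Optional Keys" list now names only `min_distance`, `threshold_abs`, and `threshold_rel`. The sketch below shows one plausible reading of how these three interact on a single probability map. It is a pure-Python illustration under stated assumptions, not the engine's code: the function name `find_peaks` is hypothetical, and the rule that the effective cutoff is the larger of the absolute and relative thresholds is an assumption. The parameter names match scikit-image's `peak_local_max`, but the diff does not say whether the model's post-processing uses that function.

```python
def find_peaks(prob, min_distance=1, threshold_abs=None, threshold_rel=None):
    """Return (row, col) local maxima of a 2D probability map (list of lists).

    Assumed semantics (hypothetical, for illustration only):
    - a pixel must be strictly above the cutoff to count;
    - the cutoff is the larger of threshold_abs and threshold_rel * max(prob);
    - a pixel is a peak if nothing within min_distance pixels exceeds it.
    """
    height, width = len(prob), len(prob[0])
    global_max = max(max(row) for row in prob)

    cutoff = 0.0
    if threshold_abs is not None:
        cutoff = max(cutoff, threshold_abs)
    if threshold_rel is not None:
        cutoff = max(cutoff, threshold_rel * global_max)

    peaks = []
    for r in range(height):
        for c in range(width):
            value = prob[r][c]
            if value <= cutoff:
                continue
            # Square neighbourhood of radius min_distance, clipped at the border.
            window = [
                prob[rr][cc]
                for rr in range(max(0, r - min_distance), min(height, r + min_distance + 1))
                for cc in range(max(0, c - min_distance), min(width, c + min_distance + 1))
            ]
            if value >= max(window):
                peaks.append((r, c))
    return peaks


# Toy 5x5 map: a strong nucleus, a weaker shoulder next to it, and a faint one.
prob_map = [[0.0] * 5 for _ in range(5)]
prob_map[1][1] = 0.9
prob_map[1][2] = 0.6
prob_map[3][4] = 0.4

all_peaks = find_peaks(prob_map, min_distance=1)
strong_only = find_peaks(prob_map, min_distance=1, threshold_abs=0.5)
```

In this toy map the shoulder at (1, 2) is suppressed by its stronger neighbour, while the faint peak at (3, 4) survives only when no threshold is set. Equal-valued neighbours would both be reported, which a real implementation would need to resolve.
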
@@ -423,63 +383,29 @@ def post_process_wsi(
423383
Processes the full-slide prediction map using Dask's block-wise operations
424384
to extract nuclei centroids across the entire WSI. The prediction map is
425385
first re-chunked to the model's preferred post-processing tile shape, and
426-
`dask.map_overlap` is used to ensure accurate centroid extraction along
427-
chunk boundaries via halo padding. The resulting centroid maps are then
428-
converted into final detection arrays (x, y, classes, probabilities).
386+
`dask.array.map_overlap` with halo padding is used to facilitate centroid
387+
extraction on large prediction maps. The resulting centroid maps are then
388+
converted into detection arrays (x, y, classes, probabilities).
429389
430390
Args:
431391
raw_predictions (da.Array):
432392
WSI prediction map of shape ``(H, W, C)`` containing
433393
per-class probabilities or logits.
434394
prediction_shape (tuple[int, ...]):
435-
Expected prediction shape (provided for consistency with the
436-
base engine interface).
395+
Expected prediction shape.
437396
prediction_dtype (type):
438-
Expected prediction dtype (also provided for consistency).
397+
Expected prediction dtype.
439398
**kwargs (NucleusDetectorRunParams):
440399
Additional runtime parameters to configure detection.
441400
442401
Optional Keys:
443-
auto_get_mask (bool):
444-
Whether to automatically generate segmentation masks using
445-
`wsireader.tissue_mask()` during processing.
446-
batch_size (int):
447-
Number of image patches to feed to the model in a forward pass.
448-
class_dict (dict):
449-
Optional dictionary mapping classification outputs to
450-
class names.
451-
device (str):
452-
Device to run the model on (e.g., "cpu", "cuda").
453-
labels (list):
454-
Optional labels for input images. Only a single label per image
455-
is supported.
456-
memory_threshold (int):
457-
Memory usage threshold (in percentage) to
458-
trigger caching behavior.
459-
num_workers (int):
460-
Number of workers used in DataLoader.
461-
output_file (str):
462-
Output file name for saving results (e.g., .zarr or .db).
463-
output_resolutions (Resolution):
464-
Resolution used for writing output predictions.
465-
patch_output_shape (tuple[int, int]):
466-
Shape of output patches (height, width).
467402
min_distance (int):
468403
Minimum distance separating two nuclei (in pixels).
469-
postproc_tile_shape (tuple[int, int]):
470-
Tile shape (height, width) for post-processing (in pixels).
471-
return_labels (bool):
472-
Whether to return labels with predictions.
473-
return_probabilities (bool):
474-
Whether to return per-class probabilities.
475-
scale_factor (tuple[float, float]):
476-
Scale factor for converting annotations to baseline resolution.
477-
Typically model_mpp / slide_mpp.
478-
stride_shape (tuple[int, int]):
479-
Stride used during WSI processing.
480-
Defaults to patch_input_shape.
481-
verbose (bool):
482-
Whether to output logging information.
404+
threshold_abs (float):
405+
Absolute detection threshold applied to model outputs.
406+
threshold_rel (float):
407+
Relative detection threshold
408+
(e.g., with respect to local maxima).
483409
484410
Returns:
485411
dict[str, da.Array]:
@@ -493,8 +419,6 @@ class names.
493419
- Halo padding ensures that nuclei crossing tile/chunk boundaries
494420
are not fragmented or duplicated.
495421
- If thresholds are not explicitly provided, model defaults are used.
496-
- The output structure matches TIAToolbox conventions so it can be
497-
saved directly as a ``dict``, ``zarr`` group, or ``annotationstore``.
498422
499423
"""
500424
_ = prediction_shape
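
Reviewer note: the docstring's claim about halo padding can be illustrated without Dask. The sketch below is a 1D, pure-Python analogue under stated assumptions: `local_maxima` and `tiled_maxima` are hypothetical helpers, not functions from the module. Each tile is extended by a halo, peaks are found on the extended tile, and only peaks whose global index lies in the tile's own core are kept. In Dask the same idea is exposed through the `depth` argument of `dask.array.map_overlap`.

```python
def local_maxima(values, min_distance):
    """Indices whose value is positive and maximal within min_distance."""
    peaks = []
    for i, v in enumerate(values):
        lo = max(0, i - min_distance)
        hi = min(len(values), i + min_distance + 1)
        if v > 0 and v >= max(values[lo:hi]):
            peaks.append(i)
    return peaks


def tiled_maxima(signal, tile_size, halo, min_distance):
    """Find peaks tile by tile, reading `halo` extra samples on each side."""
    found = []
    for start in range(0, len(signal), tile_size):
        stop = min(start + tile_size, len(signal))
        ext_start = max(0, start - halo)
        ext_stop = min(len(signal), stop + halo)
        for local_idx in local_maxima(signal[ext_start:ext_stop], min_distance):
            global_idx = ext_start + local_idx
            # Keep a peak only if it belongs to this tile's core region,
            # so neighbouring tiles never report the same peak twice.
            if start <= global_idx < stop:
                found.append(global_idx)
    return found


# One nucleus whose response straddles the boundary between tiles [0, 4) and [4, 8).
signal = [0.0, 0.2, 0.5, 0.9, 0.6, 0.1, 0.0, 0.0]

without_halo = tiled_maxima(signal, tile_size=4, halo=0, min_distance=2)
with_halo = tiled_maxima(signal, tile_size=4, halo=2, min_distance=2)
```

Without a halo, the second tile cannot see the 0.9 sample and reports its own edge value as a spurious peak. With a halo at least as wide as `min_distance`, the tiled result matches what a single pass over the whole signal would give.
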
@@ -643,8 +567,6 @@ class names.
643567
- For non-AnnotationStore outputs, this method delegates to the
644568
base engine's saving function to preserve consistency across
645569
TIAToolbox engines.
646-
- Coordinate scaling and class name mapping are applied only when saving
647-
to AnnotationStore and when provided via ``kwargs``.
648570
649571
"""
650572
if output_type.lower() != "annotationstore":
@@ -722,8 +644,6 @@ def _save_predictions_annotation_store(
722644
- This method centralizes the translation of detection arrays into
723645
`Annotation` objects and abstracts batching logic via
724646
``_write_detection_arrays_to_store``.
725-
- When writing to disk, the resulting file always uses a ``.db`` suffix,
726-
consistent with other TIAToolbox engines.
727647
728648
"""
729649
logger.info("Saving predictions as AnnotationStore.")
@@ -771,8 +691,7 @@ def _centroid_maps_to_detection_arrays(
771691
This helper function extracts non-zero centroid predictions from a
772692
Dask array of centroid maps and flattens them into coordinate,
773693
class, and probability arrays suitable for saving or further
774-
processing. The output format mirrors the detection structure used
775-
throughout TIAToolbox engines.
694+
processing.
776695
777696
Args:
778697
detection_maps (da.Array):
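
Reviewer note: the flattening step described in this docstring can be sketched in plain Python. This is an illustration under assumptions, not the module's implementation: the helper name `centroid_map_to_arrays` is hypothetical, the map is a nested list indexed as `[y][x][class]`, and the output key names other than `"x"` are assumed. Non-zero entries are treated as detections whose value is the probability.

```python
def centroid_map_to_arrays(centroid_map):
    """Flatten a (H, W, C) nested-list centroid map into parallel lists.

    Assumption: a non-zero value at [y][x][c] marks a detection of class c
    at pixel (x, y), and the value itself is the detection probability.
    """
    xs, ys, classes, probs = [], [], [], []
    for y, row in enumerate(centroid_map):
        for x, channels in enumerate(row):
            for class_id, prob in enumerate(channels):
                if prob != 0:
                    xs.append(x)
                    ys.append(y)
                    classes.append(class_id)
                    probs.append(prob)
    return {"x": xs, "y": ys, "classes": classes, "probabilities": probs}


# 2 rows x 3 columns x 2 classes, with two detections.
toy_map = [
    [[0, 0], [0.8, 0], [0, 0]],
    [[0, 0], [0, 0], [0, 0.6]],
]
arrays = centroid_map_to_arrays(toy_map)
```

The four lists stay aligned by position, so entry `i` of each list describes the same detection.
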
@@ -822,7 +741,7 @@ def _write_detection_arrays_to_store(
822741
"""Write detection arrays to an AnnotationStore in batches.
823742
824743
Converts coordinate, class, and probability arrays into `Annotation`
825-
records and appends them to an SQLite-backed store in configurable
744+
objects and appends them to an SQLite-backed store in configurable
826745
batch sizes. Coordinates are scaled to baseline slide resolution using
827746
the provided `scale_factor`, and optional class-ID remapping is applied
828747
via `class_dict`.
@@ -854,7 +773,7 @@ def _write_detection_arrays_to_store(
854773
- Class mapping is applied per-record; unmapped IDs fall back to their
855774
original values.
856775
- Writing in batches reduces memory pressure and improves throughput
857-
on large detections.
776+
for large numbers of detections.
858777
859778
"""
860779
xs, ys, classes, probs = detection_arrays
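
Reviewer note: the docstring describes three behaviours (coordinate scaling, class remapping with fallback, and batched writes) that are easy to show in isolation. The sketch below is a minimal, store-free illustration under assumptions: `iter_record_batches` and the record keys are hypothetical, plain dicts stand in for `Annotation` objects, and scaling is assumed to be a per-axis multiplication by `scale_factor`.

```python
def iter_record_batches(xs, ys, classes, probs,
                        scale_factor=(1.0, 1.0), class_dict=None, batch_size=2):
    """Yield lists of detection records, at most batch_size records each."""
    class_dict = class_dict or {}
    batch = []
    for x, y, cls, prob in zip(xs, ys, classes, probs):
        batch.append({
            # Assumed: multiply by scale_factor to reach baseline resolution.
            "x": x * scale_factor[0],
            "y": y * scale_factor[1],
            # Unmapped class IDs fall back to their original value.
            "type": class_dict.get(cls, cls),
            "prob": prob,
        })
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # flush the final, possibly shorter, batch
        yield batch


batches = list(iter_record_batches(
    xs=[10, 20, 30], ys=[1, 2, 3], classes=[0, 1, 0], probs=[0.9, 0.8, 0.7],
    scale_factor=(2.0, 2.0), class_dict={0: "nucleus"}, batch_size=2,
))
```

Because the helper is a generator, only one batch of records exists in memory at a time, which is the memory benefit the note in the docstring refers to.
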
@@ -949,13 +868,9 @@ class IDs.
949868
all detections.
950869
951870
Notes:
952-
- All detection arrays are converted to NumPy using ``np.asarray`` to
953-
ensure consistent batching.
954871
- The heavy lifting is delegated to
955872
:meth:`NucleusDetector._write_detection_arrays_to_store`,
956873
which performs coordinate scaling, class mapping, and batch writing.
957-
- The database always uses the ``.db`` suffix for consistency with
958-
TIAToolbox's AnnotationStore format TIAToolbox's AnnotationStore format.
959874
960875
"""
961876
xs = detection_arrays["x"]
@@ -1067,6 +982,11 @@ class names.
1067982
Shape of output patches (height, width).
1068983
min_distance (int):
1069984
Minimum distance separating two nuclei (in pixels).
985+
threshold_abs (float):
986+
Absolute detection threshold applied to model outputs.
987+
threshold_rel (float):
988+
Relative detection threshold
989+
(e.g., with respect to local maxima).
1070990
postproc_tile_shape (tuple[int, int]):
1071991
Tile shape (height, width) for post-processing (in pixels).
1072992
return_labels (bool):
@@ -1089,26 +1009,24 @@ class names.
10891009
to its output path.
10901010
10911011
Examples:
1092-
>>> wsis = ['wsi1.svs', 'wsi2.svs']
1093-
>>> image_patches = [np.ndarray, np.ndarray]
1094-
>>> nuc_detector = NucleusDetector(model="sccnn-conic")
1095-
>>> output = nuc_detector.run(image_patches, patch_mode=True)
1096-
>>> output
1097-
... "/path/to/Output.db"
1098-
1099-
>>> output = nuc_detector.run(
1100-
... image_patches,
1101-
... patch_mode=True,
1102-
... output_type="zarr"
1012+
>>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
1013+
>>> detector = NucleusDetector(model="mapde-conic")
1014+
>>> # WSI workflow: save to AnnotationStore (.db)
1015+
>>> out = detector.run(
1016+
... images=[pathlib.Path("example_wsi.tiff")],
1017+
... patch_mode=False,
1018+
... device="cuda",
1019+
... save_dir=pathlib.Path("output_directory/"),
1020+
... overwrite=True,
1021+
... output_type="annotationstore",
1022+
... class_dict={0: "nucleus"},
1023+
... auto_get_mask=True,
1024+
... memory_threshold=80,
11031025
... )
1104-
>>> output
1105-
... "/path/to/Output.zarr"
1106-
1107-
>>> output = nuc_detector.run(wsis, patch_mode=False)
1108-
>>> output.keys()
1109-
... ['wsi1.svs', 'wsi2.svs']
1110-
>>> output['wsi1.svs']
1111-
... "/path/to/wsi1.db"
1026+
>>> # Patch workflow: return in-memory detections
1027+
>>> patches = [np.ndarray, np.ndarray] # NHWC
1028+
>>> out = detector.run(patches, patch_mode=True, output_type="dict")
1029+
11121030
11131031
"""
11141032
return super().run(
