Skip to content

Commit 5ac5fd0

Browse files
committed
🚧 Add __init__
1 parent 36cd0e9 commit 5ac5fd0

File tree

1 file changed

+164
-125
lines changed

1 file changed

+164
-125
lines changed

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 164 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Unpack
2323

2424
from tiatoolbox.annotation import AnnotationStore
25+
from tiatoolbox.models.models_abc import ModelABC
2526
from tiatoolbox.type_hints import IntPair, Resolution, Units
2627
from tiatoolbox.wsicore import WSIReader
2728

@@ -139,137 +140,42 @@ class NucleusDetector(SemanticSegmentor):
139140
140141
"""
141142

142-
def run(
143+
def __init__(
143144
self: NucleusDetector,
144-
images: list[os.PathLike | Path | WSIReader] | np.ndarray,
145+
model: str | ModelABC,
146+
batch_size: int = 8,
147+
num_workers: int = 0,
148+
weights: str | Path | None = None,
145149
*,
146-
masks: list[os.PathLike | Path] | np.ndarray | None = None,
147-
input_resolutions: list[dict[Units, Resolution]] | None = None,
148-
patch_input_shape: IntPair | None = None,
149-
ioconfig: IOSegmentorConfig | None = None,
150-
patch_mode: bool = True,
151-
save_dir: os.PathLike | Path | None = None,
152-
overwrite: bool = False,
153-
output_type: str = "dict",
154-
**kwargs: Unpack[NucleusDetectorRunParams],
155-
) -> AnnotationStore | Path | str | dict | list[Path]:
156-
"""Run the nucleus detection engine on input images.
157-
158-
This method orchestrates the full inference pipeline, including preprocessing,
159-
model inference, post-processing, and saving results. It supports both
160-
patch-level and whole slide image (WSI) modes.
150+
device: str = "cpu",
151+
verbose: bool = True,
152+
) -> None:
153+
"""Initialize :class:`NucleusDetector`.
161154
162155
Args:
163-
images (list[PathLike | WSIReader] | np.ndarray):
164-
Input images or patches. Can be a list of file paths, WSIReader objects,
165-
or a NumPy array of image patches.
166-
masks (list[PathLike] | np.ndarray | None):
167-
Optional masks for WSI processing. Only used when `patch_mode` is False.
168-
input_resolutions (list[dict[Units, Resolution]] | None):
169-
Resolution settings for input heads. Supported units are `level`,
170-
`power`, and `mpp`. Keys should be "units" and "resolution", e.g.,
171-
[{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for
172-
details.
173-
patch_input_shape (IntPair | None):
174-
Shape of input patches (height, width), requested at read
175-
resolution. Must be positive.
176-
ioconfig (IOSegmentorConfig | None):
177-
IO configuration for patch extraction and resolution.
178-
patch_mode (bool):
179-
Whether to treat input as patches (`True`) or WSIs (`False`). Default
180-
is True.
181-
save_dir (PathLike | None):
182-
Directory to save output files. Required for WSI mode.
183-
overwrite (bool):
184-
Whether to overwrite existing output files. Default is False.
185-
output_type (str):
186-
Desired output format: "dict", "zarr", or "annotationstore". Default
187-
is "dict".
188-
**kwargs (NucleusDetectorRunParams):
189-
Additional runtime parameters to configure segmentation.
190-
191-
Optional Keys:
192-
auto_get_mask (bool):
193-
Whether to automatically generate segmentation masks using
194-
`wsireader.tissue_mask()` during processing.
195-
batch_size (int):
196-
Number of image patches to feed to the model in a forward pass.
197-
class_dict (dict):
198-
Optional dictionary mapping classification outputs to
199-
class names.
200-
device (str):
201-
Device to run the model on (e.g., "cpu", "cuda").
202-
labels (list):
203-
Optional labels for input images. Only a single label per image
204-
is supported.
205-
memory_threshold (int):
206-
Memory usage threshold (in percentage) to
207-
trigger caching behavior.
208-
num_workers (int):
209-
Number of workers used in DataLoader.
210-
output_file (str):
211-
Output file name for saving results (e.g., .zarr or .db).
212-
output_resolutions (Resolution):
213-
Resolution used for writing output predictions.
214-
patch_output_shape (tuple[int, int]):
215-
Shape of output patches (height, width).
216-
min_distance (int):
217-
Minimum distance separating two nuclei (in pixels).
218-
postproc_tile_shape (tuple[int, int]):
219-
Tile shape (height, width) for post-processing (in pixels).
220-
return_labels (bool):
221-
Whether to return labels with predictions.
222-
return_probabilities (bool):
223-
Whether to return per-class probabilities.
224-
scale_factor (tuple[float, float]):
225-
Scale factor for converting annotations to baseline resolution.
226-
Typically model_mpp / slide_mpp.
227-
stride_shape (tuple[int, int]):
228-
Stride used during WSI processing.
229-
Defaults to patch_input_shape.
230-
verbose (bool):
231-
Whether to output logging information.
232-
233-
Returns:
234-
AnnotationStore | Path | str | dict | list[Path]:
235-
- If `patch_mode` is True: returns predictions or path to saved output.
236-
- If `patch_mode` is False: returns a dictionary mapping each WSI
237-
to its output path.
238-
239-
Examples:
240-
>>> wsis = ['wsi1.svs', 'wsi2.svs']
241-
>>> image_patches = [np.ndarray, np.ndarray]
242-
>>> nuc_detector = NucleusDetector(model="sccnn-conic")
243-
>>> output = nuc_detector.run(image_patches, patch_mode=True)
244-
>>> output
245-
... "/path/to/Output.db"
246-
247-
>>> output = nuc_detector.run(
248-
... image_patches,
249-
... patch_mode=True,
250-
... output_type="zarr"
251-
... )
252-
>>> output
253-
... "/path/to/Output.zarr"
254-
255-
>>> output = nuc_detector.run(wsis, patch_mode=False)
256-
>>> output.keys()
257-
... ['wsi1.svs', 'wsi2.svs']
258-
>>> output['wsi1.svs']
259-
... "/path/to/wsi1.db"
156+
model (str | ModelABC):
157+
A PyTorch model instance or name of a pretrained model from TIAToolbox.
158+
If a string is provided, the corresponding pretrained weights will be
159+
downloaded unless overridden via `weights`.
160+
batch_size (int):
161+
Number of image patches processed per forward pass. Default is 8.
162+
num_workers (int):
163+
Number of workers for data loading. Default is 0.
164+
weights (str | Path | None):
165+
Path to model weights. If None, default weights are used.
166+
device (str):
167+
Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu".
168+
verbose (bool):
169+
Whether to enable verbose logging. Default is True.
260170
261171
"""
262-
return super().run(
263-
images=images,
264-
masks=masks,
265-
input_resolutions=input_resolutions,
266-
patch_input_shape=patch_input_shape,
267-
ioconfig=ioconfig,
268-
patch_mode=patch_mode,
269-
save_dir=save_dir,
270-
overwrite=overwrite,
271-
output_type=output_type,
272-
**kwargs,
172+
super().__init__(
173+
model=model,
174+
batch_size=batch_size,
175+
num_workers=num_workers,
176+
weights=weights,
177+
device=device,
178+
verbose=verbose,
273179
)
274180

275181
def post_process_patches(
@@ -731,3 +637,136 @@ def save_detection_arrays_to_store(
731637
return save_path
732638

733639
return store
640+
641+
def run(
642+
self: NucleusDetector,
643+
images: list[os.PathLike | Path | WSIReader] | np.ndarray,
644+
*,
645+
masks: list[os.PathLike | Path] | np.ndarray | None = None,
646+
input_resolutions: list[dict[Units, Resolution]] | None = None,
647+
patch_input_shape: IntPair | None = None,
648+
ioconfig: IOSegmentorConfig | None = None,
649+
patch_mode: bool = True,
650+
save_dir: os.PathLike | Path | None = None,
651+
overwrite: bool = False,
652+
output_type: str = "dict",
653+
**kwargs: Unpack[NucleusDetectorRunParams],
654+
) -> AnnotationStore | Path | str | dict | list[Path]:
655+
"""Run the nucleus detection engine on input images.
656+
657+
This method orchestrates the full inference pipeline, including preprocessing,
658+
model inference, post-processing, and saving results. It supports both
659+
patch-level and whole slide image (WSI) modes.
660+
661+
Args:
662+
images (list[PathLike | WSIReader] | np.ndarray):
663+
Input images or patches. Can be a list of file paths, WSIReader objects,
664+
or a NumPy array of image patches.
665+
masks (list[PathLike] | np.ndarray | None):
666+
Optional masks for WSI processing. Only used when `patch_mode` is False.
667+
input_resolutions (list[dict[Units, Resolution]] | None):
668+
Resolution settings for input heads. Supported units are `level`,
669+
`power`, and `mpp`. Keys should be "units" and "resolution", e.g.,
670+
[{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for
671+
details.
672+
patch_input_shape (IntPair | None):
673+
Shape of input patches (height, width), requested at read
674+
resolution. Must be positive.
675+
ioconfig (IOSegmentorConfig | None):
676+
IO configuration for patch extraction and resolution.
677+
patch_mode (bool):
678+
Whether to treat input as patches (`True`) or WSIs (`False`). Default
679+
is True.
680+
save_dir (PathLike | None):
681+
Directory to save output files. Required for WSI mode.
682+
overwrite (bool):
683+
Whether to overwrite existing output files. Default is False.
684+
output_type (str):
685+
Desired output format: "dict", "zarr", or "annotationstore". Default
686+
is "dict".
687+
**kwargs (NucleusDetectorRunParams):
688+
Additional runtime parameters to configure segmentation.
689+
690+
Optional Keys:
691+
auto_get_mask (bool):
692+
Whether to automatically generate segmentation masks using
693+
`wsireader.tissue_mask()` during processing.
694+
batch_size (int):
695+
Number of image patches to feed to the model in a forward pass.
696+
class_dict (dict):
697+
Optional dictionary mapping classification outputs to
698+
class names.
699+
device (str):
700+
Device to run the model on (e.g., "cpu", "cuda").
701+
labels (list):
702+
Optional labels for input images. Only a single label per image
703+
is supported.
704+
memory_threshold (int):
705+
Memory usage threshold (in percentage) to
706+
trigger caching behavior.
707+
num_workers (int):
708+
Number of workers used in DataLoader.
709+
output_file (str):
710+
Output file name for saving results (e.g., .zarr or .db).
711+
output_resolutions (Resolution):
712+
Resolution used for writing output predictions.
713+
patch_output_shape (tuple[int, int]):
714+
Shape of output patches (height, width).
715+
min_distance (int):
716+
Minimum distance separating two nuclei (in pixels).
717+
postproc_tile_shape (tuple[int, int]):
718+
Tile shape (height, width) for post-processing (in pixels).
719+
return_labels (bool):
720+
Whether to return labels with predictions.
721+
return_probabilities (bool):
722+
Whether to return per-class probabilities.
723+
scale_factor (tuple[float, float]):
724+
Scale factor for converting annotations to baseline resolution.
725+
Typically model_mpp / slide_mpp.
726+
stride_shape (tuple[int, int]):
727+
Stride used during WSI processing.
728+
Defaults to patch_input_shape.
729+
verbose (bool):
730+
Whether to output logging information.
731+
732+
Returns:
733+
AnnotationStore | Path | str | dict | list[Path]:
734+
- If `patch_mode` is True: returns predictions or path to saved output.
735+
- If `patch_mode` is False: returns a dictionary mapping each WSI
736+
to its output path.
737+
738+
Examples:
739+
>>> wsis = ['wsi1.svs', 'wsi2.svs']
740+
>>> image_patches = [np.ndarray, np.ndarray]
741+
>>> nuc_detector = NucleusDetector(model="sccnn-conic")
742+
>>> output = nuc_detector.run(image_patches, patch_mode=True)
743+
>>> output
744+
... "/path/to/Output.db"
745+
746+
>>> output = nuc_detector.run(
747+
... image_patches,
748+
... patch_mode=True,
749+
... output_type="zarr"
750+
... )
751+
>>> output
752+
... "/path/to/Output.zarr"
753+
754+
>>> output = nuc_detector.run(wsis, patch_mode=False)
755+
>>> output.keys()
756+
... ['wsi1.svs', 'wsi2.svs']
757+
>>> output['wsi1.svs']
758+
... "/path/to/wsi1.db"
759+
760+
"""
761+
return super().run(
762+
images=images,
763+
masks=masks,
764+
input_resolutions=input_resolutions,
765+
patch_input_shape=patch_input_shape,
766+
ioconfig=ioconfig,
767+
patch_mode=patch_mode,
768+
save_dir=save_dir,
769+
overwrite=overwrite,
770+
output_type=output_type,
771+
**kwargs,
772+
)

0 commit comments

Comments
 (0)