|
22 | 22 | from typing import Unpack |
23 | 23 |
|
24 | 24 | from tiatoolbox.annotation import AnnotationStore |
| 25 | + from tiatoolbox.models.models_abc import ModelABC |
25 | 26 | from tiatoolbox.type_hints import IntPair, Resolution, Units |
26 | 27 | from tiatoolbox.wsicore import WSIReader |
27 | 28 |
|
@@ -139,137 +140,42 @@ class NucleusDetector(SemanticSegmentor): |
139 | 140 |
|
140 | 141 | """ |
141 | 142 |
|
142 | | - def run( |
| 143 | + def __init__( |
143 | 144 | self: NucleusDetector, |
144 | | - images: list[os.PathLike | Path | WSIReader] | np.ndarray, |
| 145 | + model: str | ModelABC, |
| 146 | + batch_size: int = 8, |
| 147 | + num_workers: int = 0, |
| 148 | + weights: str | Path | None = None, |
145 | 149 | *, |
146 | | - masks: list[os.PathLike | Path] | np.ndarray | None = None, |
147 | | - input_resolutions: list[dict[Units, Resolution]] | None = None, |
148 | | - patch_input_shape: IntPair | None = None, |
149 | | - ioconfig: IOSegmentorConfig | None = None, |
150 | | - patch_mode: bool = True, |
151 | | - save_dir: os.PathLike | Path | None = None, |
152 | | - overwrite: bool = False, |
153 | | - output_type: str = "dict", |
154 | | - **kwargs: Unpack[NucleusDetectorRunParams], |
155 | | - ) -> AnnotationStore | Path | str | dict | list[Path]: |
156 | | - """Run the nucleus detection engine on input images. |
157 | | -
|
158 | | - This method orchestrates the full inference pipeline, including preprocessing, |
159 | | - model inference, post-processing, and saving results. It supports both |
160 | | - patch-level and whole slide image (WSI) modes. |
| 150 | + device: str = "cpu", |
| 151 | + verbose: bool = True, |
| 152 | + ) -> None: |
| 153 | + """Initialize :class:`NucleusDetector`. |
161 | 154 |
|
162 | 155 | Args: |
163 | | - images (list[PathLike | WSIReader] | np.ndarray): |
164 | | - Input images or patches. Can be a list of file paths, WSIReader objects, |
165 | | - or a NumPy array of image patches. |
166 | | - masks (list[PathLike] | np.ndarray | None): |
167 | | - Optional masks for WSI processing. Only used when `patch_mode` is False. |
168 | | - input_resolutions (list[dict[Units, Resolution]] | None): |
169 | | - Resolution settings for input heads. Supported units are `level`, |
170 | | - `power`, and `mpp`. Keys should be "units" and "resolution", e.g., |
171 | | - [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for |
172 | | - details. |
173 | | - patch_input_shape (IntPair | None): |
174 | | - Shape of input patches (height, width), requested at read |
175 | | - resolution. Must be positive. |
176 | | - ioconfig (IOSegmentorConfig | None): |
177 | | - IO configuration for patch extraction and resolution. |
178 | | - patch_mode (bool): |
179 | | - Whether to treat input as patches (`True`) or WSIs (`False`). Default |
180 | | - is True. |
181 | | - save_dir (PathLike | None): |
182 | | - Directory to save output files. Required for WSI mode. |
183 | | - overwrite (bool): |
184 | | - Whether to overwrite existing output files. Default is False. |
185 | | - output_type (str): |
186 | | - Desired output format: "dict", "zarr", or "annotationstore". Default |
187 | | - is "dict". |
188 | | - **kwargs (NucleusDetectorRunParams): |
189 | | - Additional runtime parameters to configure segmentation. |
190 | | -
|
191 | | - Optional Keys: |
192 | | - auto_get_mask (bool): |
193 | | - Whether to automatically generate segmentation masks using |
194 | | - `wsireader.tissue_mask()` during processing. |
195 | | - batch_size (int): |
196 | | - Number of image patches to feed to the model in a forward pass. |
197 | | - class_dict (dict): |
198 | | - Optional dictionary mapping classification outputs to |
199 | | - class names. |
200 | | - device (str): |
201 | | - Device to run the model on (e.g., "cpu", "cuda"). |
202 | | - labels (list): |
203 | | - Optional labels for input images. Only a single label per image |
204 | | - is supported. |
205 | | - memory_threshold (int): |
206 | | - Memory usage threshold (in percentage) to |
207 | | - trigger caching behavior. |
208 | | - num_workers (int): |
209 | | - Number of workers used in DataLoader. |
210 | | - output_file (str): |
211 | | - Output file name for saving results (e.g., .zarr or .db). |
212 | | - output_resolutions (Resolution): |
213 | | - Resolution used for writing output predictions. |
214 | | - patch_output_shape (tuple[int, int]): |
215 | | - Shape of output patches (height, width). |
216 | | - min_distance (int): |
217 | | - Minimum distance separating two nuclei (in pixels). |
218 | | - postproc_tile_shape (tuple[int, int]): |
219 | | - Tile shape (height, width) for post-processing (in pixels). |
220 | | - return_labels (bool): |
221 | | - Whether to return labels with predictions. |
222 | | - return_probabilities (bool): |
223 | | - Whether to return per-class probabilities. |
224 | | - scale_factor (tuple[float, float]): |
225 | | - Scale factor for converting annotations to baseline resolution. |
226 | | - Typically model_mpp / slide_mpp. |
227 | | - stride_shape (tuple[int, int]): |
228 | | - Stride used during WSI processing. |
229 | | - Defaults to patch_input_shape. |
230 | | - verbose (bool): |
231 | | - Whether to output logging information. |
232 | | -
|
233 | | - Returns: |
234 | | - AnnotationStore | Path | str | dict | list[Path]: |
235 | | - - If `patch_mode` is True: returns predictions or path to saved output. |
236 | | - - If `patch_mode` is False: returns a dictionary mapping each WSI |
237 | | - to its output path. |
238 | | -
|
239 | | - Examples: |
240 | | - >>> wsis = ['wsi1.svs', 'wsi2.svs'] |
241 | | - >>> image_patches = [np.ndarray, np.ndarray] |
242 | | - >>> nuc_detector = NucleusDetector(model="sccnn-conic") |
243 | | - >>> output = nuc_detector.run(image_patches, patch_mode=True) |
244 | | - >>> output |
245 | | - ... "/path/to/Output.db" |
246 | | -
|
247 | | - >>> output = nuc_detector.run( |
248 | | - ... image_patches, |
249 | | - ... patch_mode=True, |
250 | | - ... output_type="zarr" |
251 | | - ... ) |
252 | | - >>> output |
253 | | - ... "/path/to/Output.zarr" |
254 | | -
|
255 | | - >>> output = nuc_detector.run(wsis, patch_mode=False) |
256 | | - >>> output.keys() |
257 | | - ... ['wsi1.svs', 'wsi2.svs'] |
258 | | - >>> output['wsi1.svs'] |
259 | | - ... "/path/to/wsi1.db" |
| 156 | + model (str | ModelABC): |
| 157 | + A PyTorch model instance or name of a pretrained model from TIAToolbox. |
| 158 | + If a string is provided, the corresponding pretrained weights will be |
| 159 | + downloaded unless overridden via `weights`. |
| 160 | + batch_size (int): |
| 161 | + Number of image patches processed per forward pass. Default is 8. |
| 162 | + num_workers (int): |
| 163 | + Number of workers for data loading. Default is 0. |
| 164 | + weights (str | Path | None): |
| 165 | + Path to model weights. If None, default weights are used. |
| 166 | + device (str): |
| 167 | + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". |
| 168 | + verbose (bool): |
| 169 | + Whether to enable verbose logging. Default is True. |
260 | 170 |
|
261 | 171 | """ |
262 | | - return super().run( |
263 | | - images=images, |
264 | | - masks=masks, |
265 | | - input_resolutions=input_resolutions, |
266 | | - patch_input_shape=patch_input_shape, |
267 | | - ioconfig=ioconfig, |
268 | | - patch_mode=patch_mode, |
269 | | - save_dir=save_dir, |
270 | | - overwrite=overwrite, |
271 | | - output_type=output_type, |
272 | | - **kwargs, |
| 172 | + super().__init__( |
| 173 | + model=model, |
| 174 | + batch_size=batch_size, |
| 175 | + num_workers=num_workers, |
| 176 | + weights=weights, |
| 177 | + device=device, |
| 178 | + verbose=verbose, |
273 | 179 | ) |
274 | 180 |
|
275 | 181 | def post_process_patches( |
@@ -731,3 +637,136 @@ def save_detection_arrays_to_store( |
731 | 637 | return save_path |
732 | 638 |
|
733 | 639 | return store |
| 640 | + |
| 641 | + def run( |
| 642 | + self: NucleusDetector, |
| 643 | + images: list[os.PathLike | Path | WSIReader] | np.ndarray, |
| 644 | + *, |
| 645 | + masks: list[os.PathLike | Path] | np.ndarray | None = None, |
| 646 | + input_resolutions: list[dict[Units, Resolution]] | None = None, |
| 647 | + patch_input_shape: IntPair | None = None, |
| 648 | + ioconfig: IOSegmentorConfig | None = None, |
| 649 | + patch_mode: bool = True, |
| 650 | + save_dir: os.PathLike | Path | None = None, |
| 651 | + overwrite: bool = False, |
| 652 | + output_type: str = "dict", |
| 653 | + **kwargs: Unpack[NucleusDetectorRunParams], |
| 654 | + ) -> AnnotationStore | Path | str | dict | list[Path]: |
| 655 | + """Run the nucleus detection engine on input images. |
| 656 | +
|
| 657 | + This method orchestrates the full inference pipeline, including preprocessing, |
| 658 | + model inference, post-processing, and saving results. It supports both |
| 659 | + patch-level and whole slide image (WSI) modes. |
| 660 | +
|
| 661 | + Args: |
| 662 | + images (list[PathLike | WSIReader] | np.ndarray): |
| 663 | + Input images or patches. Can be a list of file paths, WSIReader objects, |
| 664 | + or a NumPy array of image patches. |
| 665 | + masks (list[PathLike] | np.ndarray | None): |
| 666 | + Optional masks for WSI processing. Only used when `patch_mode` is False. |
| 667 | + input_resolutions (list[dict[Units, Resolution]] | None): |
| 668 | + Resolution settings for input heads. Supported units are `level`, |
| 669 | + `power`, and `mpp`. Keys should be "units" and "resolution", e.g., |
| 670 | + [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for |
| 671 | + details. |
| 672 | + patch_input_shape (IntPair | None): |
| 673 | + Shape of input patches (height, width), requested at read |
| 674 | + resolution. Must be positive. |
| 675 | + ioconfig (IOSegmentorConfig | None): |
| 676 | + IO configuration for patch extraction and resolution. |
| 677 | + patch_mode (bool): |
| 678 | + Whether to treat input as patches (`True`) or WSIs (`False`). Default |
| 679 | + is True. |
| 680 | + save_dir (PathLike | None): |
| 681 | + Directory to save output files. Required for WSI mode. |
| 682 | + overwrite (bool): |
| 683 | + Whether to overwrite existing output files. Default is False. |
| 684 | + output_type (str): |
| 685 | + Desired output format: "dict", "zarr", or "annotationstore". Default |
| 686 | + is "dict". |
| 687 | + **kwargs (NucleusDetectorRunParams): |
| 688 | + Additional runtime parameters to configure segmentation. |
| 689 | +
|
| 690 | + Optional Keys: |
| 691 | + auto_get_mask (bool): |
| 692 | + Whether to automatically generate segmentation masks using |
| 693 | + `wsireader.tissue_mask()` during processing. |
| 694 | + batch_size (int): |
| 695 | + Number of image patches to feed to the model in a forward pass. |
| 696 | + class_dict (dict): |
| 697 | + Optional dictionary mapping classification outputs to |
| 698 | + class names. |
| 699 | + device (str): |
| 700 | + Device to run the model on (e.g., "cpu", "cuda"). |
| 701 | + labels (list): |
| 702 | + Optional labels for input images. Only a single label per image |
| 703 | + is supported. |
| 704 | + memory_threshold (int): |
| 705 | + Memory usage threshold (in percentage) to |
| 706 | + trigger caching behavior. |
| 707 | + num_workers (int): |
| 708 | + Number of workers used in DataLoader. |
| 709 | + output_file (str): |
| 710 | + Output file name for saving results (e.g., .zarr or .db). |
| 711 | + output_resolutions (Resolution): |
| 712 | + Resolution used for writing output predictions. |
| 713 | + patch_output_shape (tuple[int, int]): |
| 714 | + Shape of output patches (height, width). |
| 715 | + min_distance (int): |
| 716 | + Minimum distance separating two nuclei (in pixels). |
| 717 | + postproc_tile_shape (tuple[int, int]): |
| 718 | + Tile shape (height, width) for post-processing (in pixels). |
| 719 | + return_labels (bool): |
| 720 | + Whether to return labels with predictions. |
| 721 | + return_probabilities (bool): |
| 722 | + Whether to return per-class probabilities. |
| 723 | + scale_factor (tuple[float, float]): |
| 724 | + Scale factor for converting annotations to baseline resolution. |
| 725 | + Typically model_mpp / slide_mpp. |
| 726 | + stride_shape (tuple[int, int]): |
| 727 | + Stride used during WSI processing. |
| 728 | + Defaults to patch_input_shape. |
| 729 | + verbose (bool): |
| 730 | + Whether to output logging information. |
| 731 | +
|
| 732 | + Returns: |
| 733 | + AnnotationStore | Path | str | dict | list[Path]: |
| 734 | + - If `patch_mode` is True: returns predictions or path to saved output. |
| 735 | + - If `patch_mode` is False: returns a dictionary mapping each WSI |
| 736 | + to its output path. |
| 737 | +
|
| 738 | + Examples: |
| 739 | + >>> wsis = ['wsi1.svs', 'wsi2.svs'] |
| 740 | + >>> image_patches = [np.ndarray, np.ndarray] |
| 741 | + >>> nuc_detector = NucleusDetector(model="sccnn-conic") |
| 742 | + >>> output = nuc_detector.run(image_patches, patch_mode=True) |
| 743 | + >>> output |
| 744 | + ... "/path/to/Output.db" |
| 745 | +
|
| 746 | + >>> output = nuc_detector.run( |
| 747 | + ... image_patches, |
| 748 | + ... patch_mode=True, |
| 749 | + ... output_type="zarr" |
| 750 | + ... ) |
| 751 | + >>> output |
| 752 | + ... "/path/to/Output.zarr" |
| 753 | +
|
| 754 | + >>> output = nuc_detector.run(wsis, patch_mode=False) |
| 755 | + >>> output.keys() |
| 756 | + ... ['wsi1.svs', 'wsi2.svs'] |
| 757 | + >>> output['wsi1.svs'] |
| 758 | + ... "/path/to/wsi1.db" |
| 759 | +
|
| 760 | + """ |
| 761 | + return super().run( |
| 762 | + images=images, |
| 763 | + masks=masks, |
| 764 | + input_resolutions=input_resolutions, |
| 765 | + patch_input_shape=patch_input_shape, |
| 766 | + ioconfig=ioconfig, |
| 767 | + patch_mode=patch_mode, |
| 768 | + save_dir=save_dir, |
| 769 | + overwrite=overwrite, |
| 770 | + output_type=output_type, |
| 771 | + **kwargs, |
| 772 | + ) |
0 commit comments