|
4 | 4 |
|
5 | 5 | from typing import TYPE_CHECKING |
6 | 6 |
|
7 | | -from .patch_predictor import PatchPredictor |
| 7 | +from typing_extensions import Unpack |
| 8 | + |
| 9 | +from .patch_predictor import PatchPredictor, PredictorRunParams |
8 | 10 |
|
9 | 11 | if TYPE_CHECKING: # pragma: no cover |
| 12 | + import os |
10 | 13 | from pathlib import Path |
11 | 14 |
|
| 15 | + import numpy as np |
| 16 | + |
| 17 | + from tiatoolbox.annotation import AnnotationStore |
| 18 | + from tiatoolbox.models.engine.io_config import IOSegmentorConfig |
12 | 19 | from tiatoolbox.models.models_abc import ModelABC |
| 20 | + from tiatoolbox.type_hints import Resolution |
| 21 | + from tiatoolbox.wsicore import WSIReader |
| 22 | + |
| 23 | + |
| 24 | +class SemanticSegmentorRunParams(PredictorRunParams): |
| 25 | + """Class describing the input parameters for the :func:`EngineABC.run()` method. |
| 26 | +
|
| 27 | + Attributes: |
| 28 | + batch_size (int): |
| 29 | + Number of image patches to feed to the model in a forward pass. |
| 30 | + cache_mode (bool): |
| 31 | + Whether to run the Engine in cache_mode. For large datasets, |
| 32 | + we recommend to set this to True to avoid out of memory errors. |
| 33 | + For smaller datasets, the cache_mode is set to False as |
| 34 | + the results can be saved in memory. |
| 35 | + cache_size (int): |
| 36 | + Specifies how many image patches to process in a batch when |
| 37 | + cache_mode is set to True. If cache_size is less than the batch_size |
| 38 | + batch_size is set to cache_size. |
| 39 | + class_dict (dict): |
| 40 | + Optional dictionary mapping classification outputs to class names. |
| 41 | + device (str): |
| 42 | + Select the device to run the model. Please see |
| 43 | + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device |
| 44 | + for more details on input parameters for device. |
| 45 | + ioconfig (ModelIOConfigABC): |
| 46 | + Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. |
| 47 | + return_labels (bool): |
| 48 | + Whether to return the labels with the predictions. |
| 49 | + num_loader_workers (int): |
| 50 | + Number of workers used in :class:`torch.utils.data.DataLoader`. |
| 51 | + num_post_proc_workers (int): |
| 52 | + Number of workers to postprocess the results of the model. |
| 53 | + output_file (str): |
| 54 | + Output file name to save "zarr" or "db". If None, path to output is |
| 55 | + returned by the engine. |
| 56 | + patch_input_shape (tuple): |
| 57 | + Shape of patches input to the model as tuple of height and width (HW). |
| 58 | + Patches are requested at read resolution, not with respect to level 0, |
| 59 | + and must be positive. |
| 60 | + resolution (Resolution): |
| 61 | + Resolution used for reading the image. Please see |
| 62 | + :class:`WSIReader` for details. |
| 63 | + return_probabilities (bool): |
| 64 | + Whether to return per-class probabilities. |
| 65 | + scale_factor (tuple[float, float]): |
| 66 | + The scale factor to use when loading the |
| 67 | + annotations. All coordinates will be multiplied by this factor to allow |
| 68 | + conversion of annotations saved at non-baseline resolution to baseline. |
| 69 | + Should be model_mpp/slide_mpp. |
| 70 | + stride_shape (tuple): |
| 71 | + Stride used during WSI processing. Stride is |
| 72 | + at requested read resolution, not with respect to |
| 73 | + level 0, and must be positive. If not provided, |
| 74 | + `stride_shape=patch_input_shape`. |
| 75 | + units (Units): |
| 76 | + Units of resolution used for reading the image. Choose |
| 77 | + from either `level`, `power` or `mpp`. Please see |
| 78 | + :class:`WSIReader` for details. |
| 79 | + verbose (bool): |
| 80 | + Whether to output logging information. |
| 81 | +
|
| 82 | + """ |
| 83 | + |
| 84 | + patch_output_shape: tuple |
| 85 | + output_resolution: Resolution |
13 | 86 |
|
14 | 87 |
|
15 | 88 | class SemanticSegmentor(PatchPredictor): |
@@ -52,44 +125,128 @@ class SemanticSegmentor(PatchPredictor): |
52 | 125 | Use externally defined PyTorch model for prediction with |
53 | 126 | weights already loaded. Default is `None`. If provided, |
54 | 127 | `pretrained_model` argument is ignored. |
55 | | - pretrained_model (str): |
56 | | - Name of the existing models support by tiatoolbox for |
57 | | - processing the data. For a full list of pretrained models, |
| 128 | + batch_size (int): |
| 129 | + Number of images fed into the model each time. |
| 130 | + num_loader_workers (int): |
| 131 | + Number of workers to load the data using :class:`torch.utils.data.Dataset`. |
| 132 | + Please note that they will also perform preprocessing. Default value is 0. |
| 133 | + num_post_proc_workers (int): |
| 134 | + Number of workers to postprocess the results of the model. |
| 135 | + Default value is 0. |
| 136 | + weights (str or Path): |
| 137 | + Path to the weight of the corresponding `model`. |
| 138 | +
|
| 139 | + >>> engine = SemanticSegmentor( |
| 140 | + ... model="pretrained-model", |
| 141 | + ... weights="/path/to/pretrained-local-weights.pth" |
| 142 | + ... ) |
| 143 | +
|
| 144 | + verbose (bool): |
| 145 | + Whether to output logging information. |
| 146 | + device (str): |
| 147 | + Select the device to run the model. Please see |
| 148 | + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device |
| 149 | + for more details on input parameters for device. Default is "cpu". |
| 150 | + verbose (bool): |
| 151 | + Whether to output logging information. Default value is False. |
| 152 | +
|
| 153 | + Attributes: |
| 154 | + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): |
| 155 | + A list of image patches in NHWC format as a numpy array |
| 156 | + or a list of str/paths to WSIs. |
| 157 | + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): |
| 158 | + A list of tissue masks or binary masks corresponding to processing area of |
| 159 | + input images. These can be a list of numpy arrays or paths to |
| 160 | + the saved image masks. These are only utilized when patch_mode is False. |
| 161 | + Patches are only generated within a masked area. |
| 162 | + If not provided, then a tissue mask will be automatically |
| 163 | + generated for whole slide images. |
| 164 | + patch_mode (str): |
| 165 | + Whether to treat input images as a set of image patches. TIAToolbox defines |
| 166 | + an image as a patch if HWC of the input image matches with the HWC expected |
| 167 | + by the model. If HWC of the input image does not match with the HWC expected |
| 168 | + by the model, then the patch_mode must be set to False which will allow the |
| 169 | + engine to extract patches from the input image. |
| 170 | + In this case, when the patch_mode is False the input images are treated |
| 171 | + as WSIs. Default value is True. |
| 172 | + model (str | ModelABC): |
| 173 | + A PyTorch model or a name of an existing model from the TIAToolbox model zoo |
| 174 | + for processing the data. For a full list of pretrained models, |
58 | 175 | refer to the `docs |
59 | | - <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_. |
| 176 | + <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_ |
60 | 177 | By default, the corresponding pretrained weights will also |
61 | 178 | be downloaded. However, you can override with your own set |
62 | | - of weights via the `pretrained_weights` argument. Argument |
| 179 | + of weights via the `weights` argument. Argument |
63 | 180 | is case-insensitive. |
64 | | - pretrained_weights (str): |
65 | | - Path to the weight of the corresponding `pretrained_model`. |
| 181 | + ioconfig (IOSegmentorConfig): |
| 182 | + Input IO configuration of type :class:`IOSegmentorConfig` to run the Engine. |
| 183 | + _ioconfig (IOSegmentorConfig): |
| 184 | + Runtime ioconfig. |
| 185 | + return_labels (bool): |
| 186 | + Whether to return the labels with the predictions. |
| 187 | + resolution (Resolution): |
| 188 | + Resolution used for reading the image. Please see |
| 189 | + :obj:`WSIReader` for details. |
| 190 | + units (Units): |
| 191 | + Units of resolution used for reading the image. Choose |
| 192 | + from either `level`, `power` or `mpp`. Please see |
| 193 | + :obj:`WSIReader` for details. |
| 194 | + patch_input_shape (tuple): |
| 195 | + Shape of patches input to the model as tupled of HW. Patches are at |
| 196 | + requested read resolution, not with respect to level 0, |
| 197 | + and must be positive. |
| 198 | + stride_shape (tuple): |
| 199 | + Stride used during WSI processing. Stride is |
| 200 | + at requested read resolution, not with respect to |
| 201 | + level 0, and must be positive. If not provided, |
| 202 | + `stride_shape=patch_input_shape`. |
66 | 203 | batch_size (int): |
67 | 204 | Number of images fed into the model each time. |
| 205 | + cache_mode (bool): |
| 206 | + Whether to run the Engine in cache_mode. For large datasets, |
| 207 | + we recommend to set this to True to avoid out of memory errors. |
| 208 | + For smaller datasets, the cache_mode is set to False as |
| 209 | + the results can be saved in memory. cache_mode is always True when |
| 210 | + processing WSIs i.e., when `patch_mode` is False. Default value is False. |
| 211 | + cache_size (int): |
| 212 | + Specifies how many image patches to process in a batch when |
| 213 | + cache_mode is set to True. If cache_size is less than the batch_size |
| 214 | + batch_size is set to cache_size. Default value is 10,000. |
| 215 | + labels (list | None): |
| 216 | + List of labels. Only a single label per image is supported. |
| 217 | + device (str): |
| 218 | + :class:`torch.device` to run the model. |
| 219 | + Select the device to run the model. Please see |
| 220 | + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device |
| 221 | + for more details on input parameters for device. Default value is "cpu". |
68 | 222 | num_loader_workers (int): |
69 | | - Number of workers to load the data. Take note that they will |
70 | | - also perform preprocessing. |
71 | | - num_postproc_workers (int): |
72 | | - This value is there to maintain input compatibility with |
73 | | - `tiatoolbox.models.classification` and is not used. |
| 223 | + Number of workers used in :class:`torch.utils.data.DataLoader`. |
| 224 | + num_post_proc_workers (int): |
| 225 | + Number of workers to postprocess the results of the model. |
| 226 | + return_labels (bool): |
| 227 | + Whether to return the output labels. Default value is False. |
| 228 | + resolution (Resolution): |
| 229 | + Resolution used for reading the image. Please see |
| 230 | + :class:`WSIReader` for details. |
| 231 | + When `patch_mode` is True, the input image patches are expected to be at |
| 232 | + the correct resolution and units. When `patch_mode` is False, the patches |
| 233 | + are extracted at the requested resolution and units. Default value is 1.0. |
| 234 | + units (Units): |
| 235 | + Units of resolution used for reading the image. Choose |
| 236 | + from either `baseline`, `level`, `power` or `mpp`. Please see |
| 237 | + :class:`WSIReader` for details. |
| 238 | + When `patch_mode` is True, the input image patches are expected to be at |
| 239 | + the correct resolution and units. When `patch_mode` is False, the patches |
| 240 | + are extracted at the requested resolution and units. |
| 241 | + Default value is `baseline`. |
74 | 242 | verbose (bool): |
75 | | - Whether to output logging information. |
76 | | - dataset_class (obj): |
77 | | - Dataset class to be used instead of default. |
78 | | - auto_generate_mask (bool): |
79 | | - To automatically generate tile/WSI tissue mask if is not |
80 | | - provided. |
81 | | -
|
82 | | - Attributes: |
83 | | - process_prediction_per_batch (bool): |
84 | | - A flag to denote whether post-processing for inference |
85 | | - output is applied after each batch or after finishing an entire |
86 | | - tile or WSI. |
| 243 | + Whether to output logging information. Default value is False. |
87 | 244 |
|
88 | 245 | Examples: |
89 | 246 | >>> # Sample output of a network |
90 | 247 | >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] |
91 | | - >>> predictor = SemanticSegmentor(model='fcn-tissue_mask') |
92 | | - >>> output = predictor.predict(wsis, mode='wsi') |
| 248 | + >>> segmentor = SemanticSegmentor(model='fcn-tissue_mask') |
| 249 | + >>> output = segmentor.run(wsis, mode='wsi') |
93 | 250 | >>> list(output.keys()) |
94 | 251 | [('A/wsi.svs', 'output/0.raw') , ('B/wsi.svs', 'output/1.raw')] |
95 | 252 | >>> # if a network have 2 output heads, each head output of 'A/wsi.svs' |
@@ -118,3 +275,104 @@ def __init__( |
118 | 275 | device=device, |
119 | 276 | verbose=verbose, |
120 | 277 | ) |
| 278 | + |
| 279 | + def run( |
| 280 | + self: SemanticSegmentor, |
| 281 | + images: list[os | Path | WSIReader] | np.ndarray, |
| 282 | + masks: list[os | Path] | np.ndarray | None = None, |
| 283 | + labels: list | None = None, |
| 284 | + ioconfig: IOSegmentorConfig | None = None, |
| 285 | + *, |
| 286 | + patch_mode: bool = True, |
| 287 | + save_dir: os | Path | None = None, # None will not save output |
| 288 | + overwrite: bool = False, |
| 289 | + output_type: str = "dict", |
| 290 | + **kwargs: Unpack[SemanticSegmentorRunParams], |
| 291 | + ) -> AnnotationStore | Path | str | dict: |
| 292 | + """Run the engine on input images. |
| 293 | +
|
| 294 | + Args: |
| 295 | + images (list, ndarray): |
| 296 | + List of inputs to process. when using `patch` mode, the |
| 297 | + input must be either a list of images, a list of image |
| 298 | + file paths or a numpy array of an image list. |
| 299 | + masks (list | None): |
| 300 | + List of masks. Only utilised when patch_mode is False. |
| 301 | + Patches are only generated within a masked area. |
| 302 | + If not provided, then a tissue mask will be automatically |
| 303 | + generated for whole slide images. |
| 304 | + labels (list | None): |
| 305 | + List of labels. Only a single label per image is supported. |
| 306 | + patch_mode (bool): |
| 307 | + Whether to treat input image as a patch or WSI. |
| 308 | + default = True. |
| 309 | + ioconfig (IOSegmentorConfig): |
| 310 | + IO configuration. |
| 311 | + save_dir (str or pathlib.Path): |
| 312 | + Output directory to save the results. |
| 313 | + If save_dir is not provided when patch_mode is False, |
| 314 | + then for a single image the output is created in the current directory. |
| 315 | + If there are multiple WSIs as input then the user must provide |
| 316 | + path to save directory otherwise an OSError will be raised. |
| 317 | + overwrite (bool): |
| 318 | + Whether to overwrite the results. Default = False. |
| 319 | + output_type (str): |
| 320 | + The format of the output type. "output_type" can be |
| 321 | + "zarr" or "AnnotationStore". Default value is "zarr". |
| 322 | + When saving in the zarr format the output is saved using the |
| 323 | + `python zarr library <https://zarr.readthedocs.io/en/stable/>`__ |
| 324 | + as a zarr group. If the required output type is an "AnnotationStore" |
| 325 | + then the output will be intermediately saved as zarr but converted |
| 326 | + to :class:`AnnotationStore` and saved as a `.db` file |
| 327 | + at the end of the loop. |
| 328 | + **kwargs (PredictorRunParams): |
| 329 | + Keyword Args to update :class:`EngineABC` attributes during runtime. |
| 330 | +
|
| 331 | + Returns: |
| 332 | + (:class:`numpy.ndarray`, dict): |
| 333 | + Model predictions of the input dataset. If multiple |
| 334 | + whole slide images are provided as input, |
| 335 | + or save_output is True, then results are saved to |
| 336 | + `save_dir` and a dictionary indicating save location for |
| 337 | + each input is returned. |
| 338 | +
|
| 339 | + The dict has the following format: |
| 340 | +
|
| 341 | + - img_path: path of the input image. |
| 342 | + - raw: path to save location for raw prediction, |
| 343 | + saved in .json. |
| 344 | +
|
| 345 | + Examples: |
| 346 | + >>> wsis = ['wsi1.svs', 'wsi2.svs'] |
| 347 | + >>> image_patches = [np.ndarray, np.ndarray] |
| 348 | + >>> class SemanticSegmentor(PatchPredictor): |
| 349 | + >>> # Define all Abstract methods. |
| 350 | + >>> ... |
| 351 | + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") |
| 352 | + >>> output = segmentor.run(image_patches, patch_mode=True) |
| 353 | + >>> output |
| 354 | + ... "/path/to/Output.db" |
| 355 | + >>> output = segmentor.run( |
| 356 | + >>> image_patches, |
| 357 | + >>> patch_mode=True, |
| 358 | + >>> output_type="zarr") |
| 359 | + >>> output |
| 360 | + ... "/path/to/Output.zarr" |
| 361 | + >>> output = segmentor.run(wsis, patch_mode=False) |
| 362 | + >>> output.keys() |
| 363 | + ... ['wsi1.svs', 'wsi2.svs'] |
| 364 | + >>> output['wsi1.svs'] |
| 365 | + ... {'/path/to/wsi1.db'} |
| 366 | +
|
| 367 | + """ |
| 368 | + return super().run( |
| 369 | + images=images, |
| 370 | + masks=masks, |
| 371 | + labels=labels, |
| 372 | + ioconfig=ioconfig, |
| 373 | + patch_mode=patch_mode, |
| 374 | + save_dir=save_dir, |
| 375 | + overwrite=overwrite, |
| 376 | + output_type=output_type, |
| 377 | + **kwargs, |
| 378 | + ) |
0 commit comments