"""Define Deep Feature Extractor."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import numpy as np

from tiatoolbox.models import SemanticSegmentor, WSIStreamDataset

if TYPE_CHECKING:  # pragma: no cover
    from pathlib import Path

    import torch

    from tiatoolbox.models.engine.io_config import IOSegmentorConfig
    from tiatoolbox.type_hints import IntPair, Resolution, Units
    from tiatoolbox.wsicore.wsireader import WSIReader


class DeepFeatureExtractor(SemanticSegmentor):
    """Generic CNN Feature Extractor.

    An engine for using any CNN model as a feature extractor. Note
    that if `model` is supplied in the arguments, the
    `pretrained_model` and `pretrained_weights` arguments are ignored.

    Args:
        model (nn.Module):
            Use an externally defined PyTorch model for prediction,
            with weights already loaded. Default is `None`. If
            provided, the `pretrained_model` argument is ignored.
        pretrained_model (str):
            Name of an existing model supported by tiatoolbox for
            processing the data. By default, the corresponding
            pretrained weights will also be downloaded. However, you
            can override them with your own set of weights via the
            `pretrained_weights` argument. The argument is
            case-insensitive. Refer to
            :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone`
            for a list of supported pretrained models.
        pretrained_weights (str):
            Path to the weights of the corresponding
            `pretrained_model`.
        batch_size (int):
            Number of images fed into the model each time.
        num_loader_workers (int):
            Number of workers to load the data. Take note that they
            will also perform preprocessing.
        num_postproc_workers (int):
            This value is there to maintain input compatibility with
            `tiatoolbox.models.classification` and is not used.
        verbose (bool):
            Whether to output logging information.
        dataset_class (obj):
            Dataset class to be used instead of the default.
        auto_generate_mask (bool):
            Whether to automatically generate a tile/WSI tissue mask
            if one is not provided.

    Examples:
        >>> # Sample output of a network
        >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
        >>> wsis = ['A/wsi.svs', 'B/wsi.svs']
        >>> # create resnet50 with pytorch pretrained weights
        >>> model = CNNBackbone('resnet50')
        >>> predictor = DeepFeatureExtractor(model=model)
        >>> output = predictor.predict(wsis, mode='wsi')
        >>> output
        [('A/wsi.svs', 'output/0'), ('B/wsi.svs', 'output/1')]
        >>> # If a network has 2 output heads, then for 'A/wsi.svs'
        >>> # there will be 3 outputs, stored respectively at
        >>> # 'output/0.position.npy' # will always be output
        >>> # 'output/0.features.0.npy' # output of head 0
        >>> # 'output/0.features.1.npy' # output of head 1
        >>> # Each file will contain the same number of items, and the item
        >>> # at each index corresponds to 1 patch. The item in
        >>> # `.*position.npy` will be the corresponding patch bounding box.
        >>> # The box coordinates are at the inference resolution defined
        >>> # within the provided `ioconfig`.

    """

    def __init__(
        self: DeepFeatureExtractor,
        batch_size: int = 8,
        num_loader_workers: int = 0,
        num_postproc_workers: int = 0,
        model: torch.nn.Module | None = None,
        pretrained_model: str | None = None,
        pretrained_weights: str | None = None,
        dataset_class: Callable = WSIStreamDataset,
        *,
        verbose: bool = True,
        auto_generate_mask: bool = False,
    ) -> None:
        """Initialize :class:`DeepFeatureExtractor`."""
        super().__init__(
            batch_size=batch_size,
            num_loader_workers=num_loader_workers,
            num_postproc_workers=num_postproc_workers,
            model=model,
            pretrained_model=pretrained_model,
            pretrained_weights=pretrained_weights,
            verbose=verbose,
            auto_generate_mask=auto_generate_mask,
            dataset_class=dataset_class,
        )
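        # Accumulate raw predictions for the whole image and process them
        # once at the end (see `_process_predictions`) rather than per
        # batch; inferred here from the flag name and the cumulative
        # argument passed to `_process_predictions` below.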
        self.process_prediction_per_batch = False

    def _process_predictions(
        self: DeepFeatureExtractor,
        cum_batch_predictions: list,
        wsi_reader: WSIReader,  # skipcq: PYL-W0613  # noqa: ARG002
        ioconfig: IOSegmentorConfig,
        save_path: str,
        cache_dir: str,  # skipcq: PYL-W0613  # noqa: ARG002
    ) -> None:
        """Define how the aggregated predictions are processed.

        This includes merging the predictions if necessary and saving
        them afterwards.

        Args:
            cum_batch_predictions (list):
                List of batch predictions. Each item within the list
                should be of the form (location, patch_predictions).
            wsi_reader (:class:`WSIReader`):
                A reader for the image where the predictions come from.
                Not used here. Added for consistency with the API.
            ioconfig (:class:`IOSegmentorConfig`):
                A configuration object containing input and output
                information.
            save_path (str):
                Root path to save current WSI predictions.
            cache_dir (str):
                Root path to cache current WSI data.
                Not used here. Added for consistency with the API.

        """
        # assume prediction_list is N, where each item has L output elements
        location_list, prediction_list = list(zip(*cum_batch_predictions))
        # Nx4 (N x [tl_x, tl_y, br_x, br_y]) denotes the location of the
        # output patch; this can exceed the image bound at the requested
        # resolution. Remove the singleton due to the split.
        location_list = np.array([v[0] for v in location_list])
        np.save(f"{save_path}.position.npy", location_list)
        for idx, _ in enumerate(ioconfig.output_resolutions):
            # assume resolution idx to be in the same order as L;
            # use a separate variable so that `prediction_list` is not
            # clobbered between heads. The 0 index removes the singleton
            # introduced by the split without removing other singleton axes.
            head_predictions = np.array([v[idx][0] for v in prediction_list])
            np.save(f"{save_path}.features.{idx}.npy", head_predictions)
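        # The saved arrays can later be reloaded with numpy, e.g. (a
        # minimal sketch; the save path "output/0" is hypothetical):
        #   positions = np.load("output/0.position.npy")   # N x 4 boxes
        #   features = np.load("output/0.features.0.npy")  # N x D features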

    def predict(  # noqa: PLR0913
        self: DeepFeatureExtractor,
        imgs: list,
        masks: list | None = None,
        mode: str = "tile",
        ioconfig: IOSegmentorConfig | None = None,
        patch_input_shape: IntPair | None = None,
        patch_output_shape: IntPair | None = None,
        stride_shape: IntPair | None = None,
        resolution: Resolution = 1.0,
        units: Units = "baseline",
        save_dir: str | Path | None = None,
        device: str = "cpu",
        *,
        crash_on_exception: bool = False,
    ) -> list[tuple[Path, Path]]:
        """Make a prediction for a list of input data.

        By default, if the model supplied at the time of object
        instantiation is a pretrained model from the toolbox and
        `patch_input_shape`, `patch_output_shape`, `stride_shape`,
        `resolution`, `units` and `ioconfig` are all `None`, the method
        will use the `ioconfig` retrieved together with the pretrained
        model. Otherwise, either `ioconfig` or the combination of
        `patch_input_shape`, `patch_output_shape`, `stride_shape`,
        `resolution` and `units` must be set, else a `ValueError` will
        be raised.

        Args:
            imgs (list, ndarray):
                List of inputs to process. When using `"patch"` mode,
                the input must be either a list of images, a list of
                image file paths or a numpy array of an image list. When
                using `"tile"` or `"wsi"` mode, the input must be a list
                of file paths.
            masks (list):
                List of masks. Only utilised when processing image tiles
                and whole-slide images. Patches are only processed if
                they are within a masked area. If not provided, either a
                tissue mask will be automatically generated for each
                whole-slide image or all image tiles in the entire image
                will be processed.
            mode (str):
                Type of input to process. Choose from either `tile` or
                `wsi`.
            ioconfig (:class:`IOSegmentorConfig`):
                Object that defines information about input and output
                placement of patches. When provided, the
                `patch_input_shape`, `patch_output_shape`,
                `stride_shape`, `resolution`, and `units` arguments are
                ignored. Otherwise, those arguments will be internally
                converted to an :class:`IOSegmentorConfig` object.
            device (str):
                Select the :class:`torch.device` on which to run the
                model. Please see
                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
                for more details on input parameters for device. Default
                value is "cpu".
            patch_input_shape (IntPair):
                Size of patches input to the model. The values are at
                the requested read resolution and must be positive.
            patch_output_shape (tuple):
                Size of patches output by the model. The values are at
                the requested read resolution and must be positive.
            stride_shape (tuple):
                Stride used during tile and WSI processing. The values
                are at the requested read resolution and must be
                positive. If not provided,
                `stride_shape=patch_input_shape` is used.
            resolution (Resolution):
                Resolution used for reading the image.
            units (Units):
                Units of resolution used for reading the image.
            save_dir (str):
                Output directory when processing multiple tiles and
                whole-slide images. By default, it is the folder
                `output` in the directory from which the running script
                is invoked.
            crash_on_exception (bool):
                If `True`, the running loop will crash if there is any
                error during processing a WSI. Otherwise, the loop will
                move on to the next WSI for processing.

        Returns:
            list:
                A list of tuple(input_path, save_path) where
                `input_path` is the path of the input WSI while
                `save_path` corresponds to the output predictions.

        Examples:
            >>> # Sample output of a network
            >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
            >>> wsis = ['A/wsi.svs', 'B/wsi.svs']
            >>> # create resnet50 with pytorch pretrained weights
            >>> model = CNNBackbone('resnet50')
            >>> predictor = DeepFeatureExtractor(model=model)
            >>> output = predictor.predict(wsis, mode='wsi')
            >>> output
            [('A/wsi.svs', 'output/0'), ('B/wsi.svs', 'output/1')]
            >>> # If a network has 2 output heads, then for 'A/wsi.svs'
            >>> # there will be 3 outputs, stored respectively at
            >>> # 'output/0.position.npy' # will always be output
            >>> # 'output/0.features.0.npy' # output of head 0
            >>> # 'output/0.features.1.npy' # output of head 1
            >>> # Each file will contain the same number of items, and the item
            >>> # at each index corresponds to 1 patch. The item in
            >>> # `.*position.npy` will be the corresponding patch bounding box.
            >>> # The box coordinates are at the inference resolution defined
            >>> # within the provided `ioconfig`.

        """
        return super().predict(
            imgs=imgs,
            masks=masks,
            mode=mode,
            device=device,
            ioconfig=ioconfig,
            patch_input_shape=patch_input_shape,
            patch_output_shape=patch_output_shape,
            stride_shape=stride_shape,
            resolution=resolution,
            units=units,
            save_dir=save_dir,
            crash_on_exception=crash_on_exception,
        )
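

# A minimal usage sketch (illustrative only: the WSI path and the
# resolution/shape values below are assumptions, not part of this module;
# the `IOSegmentorConfig` keyword names follow the toolbox's io_config
# module):
#
#   from tiatoolbox.models.architecture.vanilla import CNNBackbone
#   from tiatoolbox.models.engine.io_config import IOSegmentorConfig
#
#   ioconfig = IOSegmentorConfig(
#       input_resolutions=[{"units": "mpp", "resolution": 0.5}],
#       output_resolutions=[{"units": "mpp", "resolution": 0.5}],
#       patch_input_shape=[224, 224],
#       patch_output_shape=[224, 224],
#       stride_shape=[224, 224],
#   )
#   extractor = DeepFeatureExtractor(model=CNNBackbone("resnet50"))
#   output = extractor.predict(
#       ["A/wsi.svs"],
#       mode="wsi",
#       ioconfig=ioconfig,
#       save_dir="feature_output",
#   )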