|
| 1 | +"""Define DeepFeatureExtractor class.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from typing_extensions import Unpack |
| 9 | + |
| 10 | +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset |
| 11 | + |
| 12 | +from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams |
| 13 | + |
| 14 | +if TYPE_CHECKING: # pragma: no cover |
| 15 | + import os |
| 16 | + from collections.abc import Callable |
| 17 | + from pathlib import Path |
| 18 | + |
| 19 | + from tiatoolbox.annotation import AnnotationStore |
| 20 | + from tiatoolbox.models.engine.io_config import IOSegmentorConfig |
| 21 | + from tiatoolbox.models.models_abc import ModelABC |
| 22 | + from tiatoolbox.wsicore import WSIReader |
| 23 | + |
| 24 | + |
| 25 | +class DeepFeatureExtractor(SemanticSegmentor): |
| 26 | + """Generic CNN Feature Extractor. |
| 27 | +
|
| 28 | + AN engine for using any CNN model as a feature extractor. Note, if |
| 29 | + `model` is supplied in the arguments, it will ignore the |
| 30 | + `pretrained_model` and `pretrained_weights` arguments. |
| 31 | +
|
| 32 | + Args: |
| 33 | + model (nn.Module): |
| 34 | + Use externally defined PyTorch model for prediction with |
| 35 | + weights already loaded. Default is `None`. If provided, |
| 36 | + `pretrained_model` argument is ignored. |
| 37 | + pretrained_model (str): |
| 38 | + Name of the existing models support by tiatoolbox for |
| 39 | + processing the data. By default, the corresponding |
| 40 | + pretrained weights will also be downloaded. However, you can |
| 41 | + override with your own set of weights via the |
| 42 | + `pretrained_weights` argument. Argument is case-insensitive. |
| 43 | + Refer to |
| 44 | + :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone` |
| 45 | + for list of supported pretrained models. |
| 46 | + pretrained_weights (str): |
| 47 | + Path to the weight of the corresponding `pretrained_model`. |
| 48 | + batch_size (int): |
| 49 | + Number of images fed into the model each time. |
| 50 | + num_loader_workers (int): |
| 51 | + Number of workers to load the data. Take note that they will |
| 52 | + also perform preprocessing. |
| 53 | + num_postproc_workers (int): |
| 54 | + This value is there to maintain input compatibility with |
| 55 | + `tiatoolbox.models.classification` and is not used. |
| 56 | + verbose (bool): |
| 57 | + Whether to output logging information. |
| 58 | + dataset_class (obj): |
| 59 | + Dataset class to be used instead of default. |
| 60 | + auto_generate_mask(bool): |
| 61 | + To automatically generate tile/WSI tissue mask if is not |
| 62 | + provided. |
| 63 | +
|
| 64 | + Examples: |
| 65 | + >>> # Sample output of a network |
| 66 | + >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone |
| 67 | + >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] |
| 68 | + >>> # create resnet50 with pytorch pretrained weights |
| 69 | + >>> model = CNNBackbone('resnet50') |
| 70 | + >>> predictor = DeepFeatureExtractor(model=model) |
| 71 | + >>> output = predictor.predict(wsis, mode='wsi') |
| 72 | + >>> list(output.keys()) |
| 73 | + [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] |
| 74 | + >>> # If a network have 2 output heads, for 'A/wsi.svs', |
| 75 | + >>> # there will be 3 outputs, and they are respectively stored at |
| 76 | + >>> # 'output/0.position.npy' # will always be output |
| 77 | + >>> # 'output/0.features.0.npy' # output of head 0 |
| 78 | + >>> # 'output/0.features.1.npy' # output of head 1 |
| 79 | + >>> # Each file will contain a same number of items, and the item at each |
| 80 | + >>> # index corresponds to 1 patch. The item in `.*position.npy` will |
| 81 | + >>> # be the corresponding patch bounding box. The box coordinates are at |
| 82 | + >>> # the inference resolution defined within the provided `ioconfig`. |
| 83 | +
|
| 84 | + """ |
| 85 | + |
| 86 | + def __init__( |
| 87 | + self: DeepFeatureExtractor, |
| 88 | + model: str | ModelABC, |
| 89 | + batch_size: int = 8, |
| 90 | + num_workers: int = 0, |
| 91 | + weights: str | Path | None = None, |
| 92 | + dataset_class: Callable = WSIStreamDataset, |
| 93 | + *, |
| 94 | + device: str = "cpu", |
| 95 | + verbose: bool = True, |
| 96 | + ) -> None: |
| 97 | + """Initialize :class:`DeepFeatureExtractor`.""" |
| 98 | + super().__init__( |
| 99 | + model=model, |
| 100 | + batch_size=batch_size, |
| 101 | + num_workers=num_workers, |
| 102 | + weights=weights, |
| 103 | + device=device, |
| 104 | + verbose=verbose, |
| 105 | + ) |
| 106 | + self.process_prediction_per_batch = False |
| 107 | + self.dataset_class = dataset_class |
| 108 | + |
| 109 | + def _process_predictions( |
| 110 | + self: DeepFeatureExtractor, |
| 111 | + cum_batch_predictions: list, |
| 112 | + wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 |
| 113 | + ioconfig: IOSegmentorConfig, |
| 114 | + save_path: str, |
| 115 | + cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 |
| 116 | + ) -> None: |
| 117 | + """Define how the aggregated predictions are processed. |
| 118 | +
|
| 119 | + This includes merging the prediction if necessary and also |
| 120 | + saving afterward. |
| 121 | +
|
| 122 | + Args: |
| 123 | + cum_batch_predictions (list): |
| 124 | + List of batch predictions. Each item within the list |
| 125 | + should be of (location, patch_predictions). |
| 126 | + wsi_reader (:class:`WSIReader`): |
| 127 | + A reader for the image where the predictions come from. |
| 128 | + Not used here. Added for consistency with the API. |
| 129 | + ioconfig (:class:`IOSegmentorConfig`): |
| 130 | + A configuration object contains input and output |
| 131 | + information. |
| 132 | + save_path (str): |
| 133 | + Root path to save current WSI predictions. |
| 134 | + cache_dir (str): |
| 135 | + Root path to cache current WSI data. |
| 136 | + Not used here. Added for consistency with the API. |
| 137 | +
|
| 138 | + """ |
| 139 | + # assume prediction_list is N, each item has L output elements |
| 140 | + location_list, prediction_list = list(zip(*cum_batch_predictions, strict=False)) |
| 141 | + # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of output |
| 142 | + # patch, this can exceed the image bound at the requested resolution |
| 143 | + # remove singleton due to split. |
| 144 | + location_list = np.array([v[0] for v in location_list]) |
| 145 | + np.save(f"{save_path}.position.npy", location_list) |
| 146 | + for idx, _ in enumerate(ioconfig.output_resolutions): |
| 147 | + # assume resolution idx to be in the same order as L |
| 148 | + # 0 idx is to remove singleton without removing other axes singleton |
| 149 | + prediction_list = [v[idx][0] for v in prediction_list] |
| 150 | + prediction_list = np.array(prediction_list) |
| 151 | + np.save(f"{save_path}.features.{idx}.npy", prediction_list) |
| 152 | + |
| 153 | + def run( |
| 154 | + self: DeepFeatureExtractor, |
| 155 | + images: list[os.PathLike | Path | WSIReader] | np.ndarray, |
| 156 | + masks: list[os.PathLike | Path] | np.ndarray | None = None, |
| 157 | + labels: list | None = None, |
| 158 | + ioconfig: IOSegmentorConfig | None = None, |
| 159 | + *, |
| 160 | + patch_mode: bool = True, |
| 161 | + save_dir: os.PathLike | Path | None = None, |
| 162 | + overwrite: bool = False, |
| 163 | + output_type: str = "dict", |
| 164 | + **kwargs: Unpack[SemanticSegmentorRunParams], |
| 165 | + ) -> AnnotationStore | Path | str | dict | list[Path]: |
| 166 | + """Run the DeepFeatureExtractor engine on input images.""" |
| 167 | + return super().run( |
| 168 | + images=images, |
| 169 | + masks=masks, |
| 170 | + labels=labels, |
| 171 | + ioconfig=ioconfig, |
| 172 | + patch_mode=patch_mode, |
| 173 | + save_dir=save_dir, |
| 174 | + overwrite=overwrite, |
| 175 | + output_type=output_type, |
| 176 | + **kwargs, |
| 177 | + ) |
0 commit comments