Skip to content

Commit 6ccb300

Browse files
committed
⏪ Add DeepFeatureExtractor.
1 parent 3aaed3e commit 6ccb300

File tree

1 file changed

+273
-0
lines changed

1 file changed

+273
-0
lines changed
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""Define Deep Feature Extractor."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Callable
6+
7+
import numpy as np
8+
9+
from tiatoolbox.models import SemanticSegmentor, WSIStreamDataset
10+
11+
if TYPE_CHECKING: # pragma: no cover
12+
from pathlib import Path
13+
14+
import torch
15+
16+
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
17+
from tiatoolbox.type_hints import IntPair, Resolution, Units
18+
from tiatoolbox.wsicore.wsireader import WSIReader
19+
20+
21+
class DeepFeatureExtractor(SemanticSegmentor):
    """Generic CNN Feature Extractor.

    An engine for using any CNN model as a feature extractor. Note, if
    `model` is supplied in the arguments, it will ignore the
    `pretrained_model` and `pretrained_weights` arguments.

    Args:
        model (nn.Module):
            Use externally defined PyTorch model for prediction with
            weights already loaded. Default is `None`. If provided,
            `pretrained_model` argument is ignored.
        pretrained_model (str):
            Name of the existing models supported by tiatoolbox for
            processing the data. By default, the corresponding
            pretrained weights will also be downloaded. However, you can
            override with your own set of weights via the
            `pretrained_weights` argument. Argument is case-insensitive.
            Refer to
            :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone`
            for list of supported pretrained models.
        pretrained_weights (str):
            Path to the weight of the corresponding `pretrained_model`.
        batch_size (int):
            Number of images fed into the model each time.
        num_loader_workers (int):
            Number of workers to load the data. Take note that they will
            also perform preprocessing.
        num_postproc_workers (int):
            This value is there to maintain input compatibility with
            `tiatoolbox.models.classification` and is not used.
        verbose (bool):
            Whether to output logging information.
        dataset_class (obj):
            Dataset class to be used instead of default.
        auto_generate_mask (bool):
            To automatically generate tile/WSI tissue mask if it is not
            provided.

    Examples:
        >>> # Sample output of a network
        >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
        >>> wsis = ['A/wsi.svs', 'B/wsi.svs']
        >>> # create resnet50 with pytorch pretrained weights
        >>> model = CNNBackbone('resnet50')
        >>> predictor = DeepFeatureExtractor(model=model)
        >>> output = predictor.predict(wsis, mode='wsi')
        >>> output
        [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')]
        >>> # If a network has 2 output heads, for 'A/wsi.svs',
        >>> # there will be 3 outputs, and they are respectively stored at
        >>> # 'output/0.position.npy'   # will always be output
        >>> # 'output/0.features.0.npy' # output of head 0
        >>> # 'output/0.features.1.npy' # output of head 1
        >>> # Each file will contain the same number of items, and the item at
        >>> # each index corresponds to 1 patch. The item in `.*position.npy`
        >>> # will be the corresponding patch bounding box. The box coordinates
        >>> # are at the inference resolution defined within the provided
        >>> # `ioconfig`.

    """

    def __init__(
        self: DeepFeatureExtractor,
        batch_size: int = 8,
        num_loader_workers: int = 0,
        num_postproc_workers: int = 0,
        model: torch.nn.Module | None = None,
        pretrained_model: str | None = None,
        pretrained_weights: str | None = None,
        dataset_class: Callable = WSIStreamDataset,
        *,
        verbose: bool = True,
        auto_generate_mask: bool = False,
    ) -> None:
        """Initialize :class:`DeepFeatureExtractor`.

        All arguments are forwarded unchanged to the parent engine; see the
        class docstring for their meaning.

        """
        super().__init__(
            batch_size=batch_size,
            num_loader_workers=num_loader_workers,
            num_postproc_workers=num_postproc_workers,
            model=model,
            pretrained_model=pretrained_model,
            pretrained_weights=pretrained_weights,
            verbose=verbose,
            auto_generate_mask=auto_generate_mask,
            dataset_class=dataset_class,
        )
        # NOTE(review): presumably makes the parent engine hand over all batch
        # predictions at once (see `_process_predictions`) instead of after
        # every batch - confirm against the base class.
        self.process_prediction_per_batch = False

    def _process_predictions(
        self: DeepFeatureExtractor,
        cum_batch_predictions: list,
        wsi_reader: WSIReader,  # skipcq: PYL-W0613  # noqa: ARG002
        ioconfig: IOSegmentorConfig,
        save_path: str,
        cache_dir: str,  # skipcq: PYL-W0613  # noqa: ARG002
    ) -> None:
        """Define how the aggregated predictions are processed.

        The patch locations are stacked and written to
        `{save_path}.position.npy`. Then, for every entry in
        `ioconfig.output_resolutions`, the features of the matching output
        head are stacked and written to `{save_path}.features.{idx}.npy`.

        Args:
            cum_batch_predictions (list):
                List of batch predictions. Each item within the list
                should be of (location, patch_predictions).
            wsi_reader (:class:`WSIReader`):
                A reader for the image where the predictions come from.
                Not used here. Added for consistency with the API.
            ioconfig (:class:`IOSegmentorConfig`):
                A configuration object contains input and output
                information.
            save_path (str):
                Root path to save current WSI predictions.
            cache_dir (str):
                Root path to cache current WSI data.
                Not used here. Added for consistency with the API.

        """
        # assume prediction_list is N, each item has L output elements
        location_list, prediction_list = zip(*cum_batch_predictions)
        # Nx4 (N x [tl_x, tl_y, br_x, br_y]), denotes the location of output
        # patch, this can exceed the image bound at the requested resolution
        # remove singleton due to split.
        location_array = np.array([v[0] for v in location_list])
        np.save(f"{save_path}.position.npy", location_array)
        for idx, _ in enumerate(ioconfig.output_resolutions):
            # assume resolution idx to be in the same order as L
            # 0 idx is to remove singleton without removing other axes singleton
            #
            # The stacked features get their own name on purpose: rebinding
            # `prediction_list` here would make every head after the first
            # index into head 0's array instead of the raw patch predictions.
            head_features = np.array([v[idx][0] for v in prediction_list])
            np.save(f"{save_path}.features.{idx}.npy", head_features)

    def predict(  # noqa: PLR0913
        self: DeepFeatureExtractor,
        imgs: list,
        masks: list | None = None,
        mode: str = "tile",
        ioconfig: IOSegmentorConfig | None = None,
        patch_input_shape: IntPair | None = None,
        patch_output_shape: IntPair | None = None,
        stride_shape: IntPair | None = None,
        resolution: Resolution = 1.0,
        units: Units = "baseline",
        save_dir: str | Path | None = None,
        device: str = "cpu",
        *,
        crash_on_exception: bool = False,
    ) -> list[tuple[Path, Path]]:
        """Make a prediction for a list of input data.

        By default, if the input model at the time of object
        instantiation is a pretrained model in the toolbox and
        `patch_input_shape`, `patch_output_shape`, `stride_shape`,
        `resolution`, `units` and `ioconfig` are `None`, the method will
        use the `ioconfig` retrieved together with the pretrained model.
        Otherwise, either `patch_input_shape`, `patch_output_shape`,
        `stride_shape`, `resolution`, `units` or `ioconfig` must be set
        - else a `ValueError` will be raised.

        Args:
            imgs (list, ndarray):
                List of inputs to process. When using `"patch"` mode,
                the input must be either a list of images, a list of
                image file paths or a numpy array of an image list. When
                using `"tile"` or `"wsi"` mode, the input must be a list
                of file paths.
            masks (list):
                List of masks. Only utilised when processing image tiles
                and whole-slide images. Patches are only processed if
                they are within a masked area. If not provided, then a
                tissue mask will be automatically generated for each
                whole-slide image or all image tiles in the entire image
                are processed.
            mode (str):
                Type of input to process. Choose from either `tile` or
                `wsi`.
            ioconfig (:class:`IOSegmentorConfig`):
                Object that defines information about input and output
                placement of patches. When provided,
                `patch_input_shape`, `patch_output_shape`,
                `stride_shape`, `resolution`, and `units` arguments are
                ignored. Otherwise, those arguments will be internally
                converted to a :class:`IOSegmentorConfig` object.
            patch_input_shape (IntPair):
                Size of patches input to the model. The values are at
                requested read resolution and must be positive.
            patch_output_shape (tuple):
                Size of patches output by the model. The values are at
                the requested read resolution and must be positive.
            stride_shape (tuple):
                Stride used during tile and WSI processing. The values
                are at requested read resolution and must be positive.
                If not provided, `stride_shape=patch_input_shape` is
                used.
            resolution (Resolution):
                Resolution used for reading the image.
            units (Units):
                Units of resolution used for reading the image.
            save_dir (str):
                Output directory when processing multiple tiles and
                whole-slide images. By default, it is folder `output`
                where the running script is invoked.
            device (str):
                :class:`torch.device` to run the model.
                Select the device to run the model. Please see
                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
                for more details on input parameters for device. Default
                value is "cpu".
            crash_on_exception (bool):
                If `True`, the running loop will crash if there is any
                error during processing a WSI. Otherwise, the loop will
                move on to the next wsi for processing.

        Returns:
            list:
                A list of tuple(input_path, save_path) where
                `input_path` is the path of the input wsi while
                `save_path` corresponds to the output predictions.

        Examples:
            >>> # Sample output of a network
            >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
            >>> wsis = ['A/wsi.svs', 'B/wsi.svs']
            >>> # create resnet50 with pytorch pretrained weights
            >>> model = CNNBackbone('resnet50')
            >>> predictor = DeepFeatureExtractor(model=model)
            >>> output = predictor.predict(wsis, mode='wsi')
            >>> output
            [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')]
            >>> # If a network has 2 output heads, for 'A/wsi.svs',
            >>> # there will be 3 outputs, and they are respectively stored at
            >>> # 'output/0.position.npy'   # will always be output
            >>> # 'output/0.features.0.npy' # output of head 0
            >>> # 'output/0.features.1.npy' # output of head 1
            >>> # Each file will contain the same number of items, and the item
            >>> # at each index corresponds to 1 patch. The item in
            >>> # `.*position.npy` will be the corresponding patch bounding box.
            >>> # The box coordinates are at the inference resolution defined
            >>> # within the provided `ioconfig`.

        """
        # Pure pass-through: the parent engine drives the whole pipeline and
        # calls back into `_process_predictions` to write the feature files.
        return super().predict(
            imgs=imgs,
            masks=masks,
            mode=mode,
            device=device,
            ioconfig=ioconfig,
            patch_input_shape=patch_input_shape,
            patch_output_shape=patch_output_shape,
            stride_shape=stride_shape,
            resolution=resolution,
            units=units,
            save_dir=save_dir,
            crash_on_exception=crash_on_exception,
        )

0 commit comments

Comments
 (0)