Skip to content

Commit cd368bd

Browse files
committed
🆕 Define DeepFeatureExtractor
1 parent b542c9a commit cd368bd

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""Define DeepFeatureExtractor class."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
from typing_extensions import Unpack
9+
10+
from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset
11+
12+
from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams
13+
14+
if TYPE_CHECKING: # pragma: no cover
15+
import os
16+
from collections.abc import Callable
17+
from pathlib import Path
18+
19+
from tiatoolbox.annotation import AnnotationStore
20+
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
21+
from tiatoolbox.models.models_abc import ModelABC
22+
from tiatoolbox.wsicore import WSIReader
23+
24+
25+
class DeepFeatureExtractor(SemanticSegmentor):
26+
"""Generic CNN Feature Extractor.
27+
28+
AN engine for using any CNN model as a feature extractor. Note, if
29+
`model` is supplied in the arguments, it will ignore the
30+
`pretrained_model` and `pretrained_weights` arguments.
31+
32+
Args:
33+
model (nn.Module):
34+
Use externally defined PyTorch model for prediction with
35+
weights already loaded. Default is `None`. If provided,
36+
`pretrained_model` argument is ignored.
37+
pretrained_model (str):
38+
Name of the existing models support by tiatoolbox for
39+
processing the data. By default, the corresponding
40+
pretrained weights will also be downloaded. However, you can
41+
override with your own set of weights via the
42+
`pretrained_weights` argument. Argument is case-insensitive.
43+
Refer to
44+
:class:`tiatoolbox.models.architecture.vanilla.CNNBackbone`
45+
for list of supported pretrained models.
46+
pretrained_weights (str):
47+
Path to the weight of the corresponding `pretrained_model`.
48+
batch_size (int):
49+
Number of images fed into the model each time.
50+
num_loader_workers (int):
51+
Number of workers to load the data. Take note that they will
52+
also perform preprocessing.
53+
num_postproc_workers (int):
54+
This value is there to maintain input compatibility with
55+
`tiatoolbox.models.classification` and is not used.
56+
verbose (bool):
57+
Whether to output logging information.
58+
dataset_class (obj):
59+
Dataset class to be used instead of default.
60+
auto_generate_mask(bool):
61+
To automatically generate tile/WSI tissue mask if is not
62+
provided.
63+
64+
Examples:
65+
>>> # Sample output of a network
66+
>>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
67+
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
68+
>>> # create resnet50 with pytorch pretrained weights
69+
>>> model = CNNBackbone('resnet50')
70+
>>> predictor = DeepFeatureExtractor(model=model)
71+
>>> output = predictor.predict(wsis, mode='wsi')
72+
>>> list(output.keys())
73+
[('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')]
74+
>>> # If a network have 2 output heads, for 'A/wsi.svs',
75+
>>> # there will be 3 outputs, and they are respectively stored at
76+
>>> # 'output/0.position.npy' # will always be output
77+
>>> # 'output/0.features.0.npy' # output of head 0
78+
>>> # 'output/0.features.1.npy' # output of head 1
79+
>>> # Each file will contain a same number of items, and the item at each
80+
>>> # index corresponds to 1 patch. The item in `.*position.npy` will
81+
>>> # be the corresponding patch bounding box. The box coordinates are at
82+
>>> # the inference resolution defined within the provided `ioconfig`.
83+
84+
"""
85+
86+
def __init__(
87+
self: DeepFeatureExtractor,
88+
model: str | ModelABC,
89+
batch_size: int = 8,
90+
num_workers: int = 0,
91+
weights: str | Path | None = None,
92+
dataset_class: Callable = WSIStreamDataset,
93+
*,
94+
device: str = "cpu",
95+
verbose: bool = True,
96+
) -> None:
97+
"""Initialize :class:`DeepFeatureExtractor`."""
98+
super().__init__(
99+
model=model,
100+
batch_size=batch_size,
101+
num_workers=num_workers,
102+
weights=weights,
103+
device=device,
104+
verbose=verbose,
105+
)
106+
self.process_prediction_per_batch = False
107+
self.dataset_class = dataset_class
108+
109+
def _process_predictions(
110+
self: DeepFeatureExtractor,
111+
cum_batch_predictions: list,
112+
wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002
113+
ioconfig: IOSegmentorConfig,
114+
save_path: str,
115+
cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002
116+
) -> None:
117+
"""Define how the aggregated predictions are processed.
118+
119+
This includes merging the prediction if necessary and also
120+
saving afterward.
121+
122+
Args:
123+
cum_batch_predictions (list):
124+
List of batch predictions. Each item within the list
125+
should be of (location, patch_predictions).
126+
wsi_reader (:class:`WSIReader`):
127+
A reader for the image where the predictions come from.
128+
Not used here. Added for consistency with the API.
129+
ioconfig (:class:`IOSegmentorConfig`):
130+
A configuration object contains input and output
131+
information.
132+
save_path (str):
133+
Root path to save current WSI predictions.
134+
cache_dir (str):
135+
Root path to cache current WSI data.
136+
Not used here. Added for consistency with the API.
137+
138+
"""
139+
# assume prediction_list is N, each item has L output elements
140+
location_list, prediction_list = list(zip(*cum_batch_predictions, strict=False))
141+
# Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of output
142+
# patch, this can exceed the image bound at the requested resolution
143+
# remove singleton due to split.
144+
location_list = np.array([v[0] for v in location_list])
145+
np.save(f"{save_path}.position.npy", location_list)
146+
for idx, _ in enumerate(ioconfig.output_resolutions):
147+
# assume resolution idx to be in the same order as L
148+
# 0 idx is to remove singleton without removing other axes singleton
149+
prediction_list = [v[idx][0] for v in prediction_list]
150+
prediction_list = np.array(prediction_list)
151+
np.save(f"{save_path}.features.{idx}.npy", prediction_list)
152+
153+
def run(
154+
self: DeepFeatureExtractor,
155+
images: list[os.PathLike | Path | WSIReader] | np.ndarray,
156+
masks: list[os.PathLike | Path] | np.ndarray | None = None,
157+
labels: list | None = None,
158+
ioconfig: IOSegmentorConfig | None = None,
159+
*,
160+
patch_mode: bool = True,
161+
save_dir: os.PathLike | Path | None = None,
162+
overwrite: bool = False,
163+
output_type: str = "dict",
164+
**kwargs: Unpack[SemanticSegmentorRunParams],
165+
) -> AnnotationStore | Path | str | dict | list[Path]:
166+
"""Run the DeepFeatureExtractor engine on input images."""
167+
return super().run(
168+
images=images,
169+
masks=masks,
170+
labels=labels,
171+
ioconfig=ioconfig,
172+
patch_mode=patch_mode,
173+
save_dir=save_dir,
174+
overwrite=overwrite,
175+
output_type=output_type,
176+
**kwargs,
177+
)

0 commit comments

Comments
 (0)