Skip to content

Commit 193c587

Browse files
committed
✨ Define SemanticSegmentor with the New EngineABC
1 parent 8c2f50b commit 193c587

File tree

3 files changed

+122
-24
lines changed

3 files changed

+122
-24
lines changed

tiatoolbox/models/engine/engine_abc.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,6 @@ class EngineABCRunParams(TypedDict, total=False):
120120
Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine.
121121
return_labels (bool):
122122
Whether to return the labels with the predictions.
123-
merge_predictions (bool):
124-
Whether to merge the predictions to form a 2-dimensional
125-
map into a single file from a WSI.
126-
This is only applicable if `patch_mode` is False in inference.
127123
num_loader_workers (int):
128124
Number of workers used in :class:`torch.utils.data.DataLoader`.
129125
num_post_proc_workers (int):
@@ -165,7 +161,6 @@ class EngineABCRunParams(TypedDict, total=False):
165161
class_dict: dict
166162
device: str
167163
ioconfig: ModelIOConfigABC
168-
merge_predictions: bool
169164
num_loader_workers: int
170165
num_post_proc_workers: int
171166
output_file: str
@@ -248,10 +243,6 @@ class EngineABC(ABC):
248243
Runtime ioconfig.
249244
return_labels (bool):
250245
Whether to return the labels with the predictions.
251-
merge_predictions (bool):
252-
Whether to merge the predictions to form a 2-dimensional
253-
map. This is only applicable if `patch_mode` is False in inference.
254-
Default is False.
255246
resolution (Resolution):
256247
Resolution used for reading the image. Please see
257248
:obj:`WSIReader` for details.
@@ -293,8 +284,6 @@ class EngineABC(ABC):
293284
Number of workers to postprocess the results of the model.
294285
return_labels (bool):
295286
Whether to return the output labels. Default value is False.
296-
merge_predictions (bool):
297-
Whether to merge WSI predictions into a single file. Default value is False.
298287
resolution (Resolution):
299288
Resolution used for reading the image. Please see
300289
:class:`WSIReader` for details.
@@ -374,7 +363,6 @@ def __init__(
374363
self.cache_mode: bool = False
375364
self.cache_size: int = self.batch_size if self.batch_size else 10000
376365
self.labels: list | None = None
377-
self.merge_predictions: bool = False
378366
self.num_loader_workers = num_loader_workers
379367
self.num_post_proc_workers = num_post_proc_workers
380368
self.patch_input_shape: IntPair | None = None
@@ -1194,8 +1182,6 @@ def run(
11941182
- img_path: path of the input image.
11951183
- raw: path to save location for raw prediction,
11961184
saved in .json.
1197-
- merged: path to .npy contain merged
1198-
predictions if `merge_predictions` is `True`.
11991185
12001186
Examples:
12011187
>>> wsis = ['wsi1.svs', 'wsi2.svs']

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Defines Abstract Base Class for TIAToolbox Model Engines."""
1+
"""Defines PatchPredictor Engine."""
22

33
from __future__ import annotations
44

@@ -25,7 +25,7 @@
2525
class PatchPredictor(EngineABC):
2626
r"""Patch level predictor for digital histology images.
2727
28-
The models provided by tiatoolbox should give the following results:
28+
The models provided by TIAToolbox should give the following results:
2929
3030
.. list-table:: PatchPredictor performance on the Kather100K dataset [1]
3131
:widths: 15 15
@@ -176,10 +176,6 @@ class PatchPredictor(EngineABC):
176176
Runtime ioconfig.
177177
return_labels (bool):
178178
Whether to return the labels with the predictions.
179-
merge_predictions (bool):
180-
Whether to merge the predictions to form a 2-dimensional
181-
map. This is only applicable if `patch_mode` is False in inference.
182-
Default is False.
183179
resolution (Resolution):
184180
Resolution used for reading the image. Please see
185181
:obj:`WSIReader` for details.
@@ -221,8 +217,6 @@ class PatchPredictor(EngineABC):
221217
Number of workers to postprocess the results of the model.
222218
return_labels (bool):
223219
Whether to return the output labels. Default value is False.
224-
merge_predictions (bool):
225-
Whether to merge WSI predictions into a single file. Default value is False.
226220
resolution (Resolution):
227221
Resolution used for reading the image. Please see
228222
:class:`WSIReader` for details.
@@ -482,8 +476,6 @@ def run(
482476
- img_path: path of the input image.
483477
- raw: path to save location for raw prediction,
484478
saved in .json.
485-
- merged: path to .npy contain merged
486-
predictions if `merge_predictions` is `True`.
487479
488480
Examples:
489481
>>> wsis = ['wsi1.svs', 'wsi2.svs']
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Defines SemanticSegmentor Engine."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
from .patch_predictor import PatchPredictor
8+
9+
if TYPE_CHECKING: # pragma: no cover
10+
from pathlib import Path
11+
12+
from tiatoolbox.models.models_abc import ModelABC
13+
14+
15+
class SemanticSegmentor(PatchPredictor):
16+
"""Pixel-wise segmentation predictor.
17+
18+
The tiatoolbox model should produce the following results on the BCSS dataset
19+
using fcn_resnet50_unet-bcss.
20+
21+
.. list-table:: Semantic segmentation performance on the BCSS dataset
22+
:widths: 15 15 15 15 15 15 15
23+
:header-rows: 1
24+
25+
* -
26+
- Tumour
27+
- Stroma
28+
- Inflammatory
29+
- Necrosis
30+
- Other
31+
- All
32+
* - Amgad et al.
33+
- 0.851
34+
- 0.800
35+
- 0.712
36+
- 0.723
37+
- 0.666
38+
- 0.750
39+
* - TIAToolbox
40+
- 0.885
41+
- 0.825
42+
- 0.761
43+
- 0.765
44+
- 0.581
45+
- 0.763
46+
47+
Note, if `model` is supplied in the arguments, it will ignore the
48+
`pretrained_model` and `pretrained_weights` arguments.
49+
50+
Args:
51+
model (nn.Module):
52+
Use externally defined PyTorch model for prediction with
53+
weights already loaded. Default is `None`. If provided,
54+
`pretrained_model` argument is ignored.
55+
pretrained_model (str):
56+
Name of the existing models support by tiatoolbox for
57+
processing the data. For a full list of pretrained models,
58+
refer to the `docs
59+
<https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
60+
By default, the corresponding pretrained weights will also
61+
be downloaded. However, you can override with your own set
62+
of weights via the `pretrained_weights` argument. Argument
63+
is case-insensitive.
64+
pretrained_weights (str):
65+
Path to the weight of the corresponding `pretrained_model`.
66+
batch_size (int):
67+
Number of images fed into the model each time.
68+
num_loader_workers (int):
69+
Number of workers to load the data. Take note that they will
70+
also perform preprocessing.
71+
num_postproc_workers (int):
72+
This value is there to maintain input compatibility with
73+
`tiatoolbox.models.classification` and is not used.
74+
verbose (bool):
75+
Whether to output logging information.
76+
dataset_class (obj):
77+
Dataset class to be used instead of default.
78+
auto_generate_mask (bool):
79+
To automatically generate tile/WSI tissue mask if is not
80+
provided.
81+
82+
Attributes:
83+
process_prediction_per_batch (bool):
84+
A flag to denote whether post-processing for inference
85+
output is applied after each batch or after finishing an entire
86+
tile or WSI.
87+
88+
Examples:
89+
>>> # Sample output of a network
90+
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
91+
>>> predictor = SemanticSegmentor(model='fcn-tissue_mask')
92+
>>> output = predictor.predict(wsis, mode='wsi')
93+
>>> list(output.keys())
94+
[('A/wsi.svs', 'output/0.raw') , ('B/wsi.svs', 'output/1.raw')]
95+
>>> # if a network have 2 output heads, each head output of 'A/wsi.svs'
96+
>>> # will be respectively stored in 'output/0.raw.0', 'output/0.raw.1'
97+
98+
"""
99+
100+
def __init__(
101+
self: SemanticSegmentor,
102+
model: str | ModelABC,
103+
batch_size: int = 8,
104+
num_loader_workers: int = 0,
105+
num_post_proc_workers: int = 0,
106+
weights: str | Path | None = None,
107+
*,
108+
device: str = "cpu",
109+
verbose: bool = True,
110+
) -> None:
111+
"""Initialize :class:`SemanticSegmentor`."""
112+
super().__init__(
113+
model=model,
114+
batch_size=batch_size,
115+
num_loader_workers=num_loader_workers,
116+
num_post_proc_workers=num_post_proc_workers,
117+
weights=weights,
118+
device=device,
119+
verbose=verbose,
120+
)

0 commit comments

Comments
 (0)