Skip to content

Commit 40e79a1

Browse files
committed
🔨 Add run method to SemanticSegmentor
1 parent 967dba1 commit 40e79a1

File tree

4 files changed

+305
-28
lines changed

4 files changed

+305
-28
lines changed

tests/engines/test_patch_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Test for Patch Predictor."""
1+
"""Test PatchPredictor."""
22

33
from __future__ import annotations
44

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Test SemanticSegmentor."""
2+
3+
from __future__ import annotations
4+
5+
import torch
6+
7+
from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor
8+
from tiatoolbox.utils import env_detection as toolbox_env
9+
10+
device = "cuda" if toolbox_env.has_gpu() else "cpu"
11+
12+
13+
def test_semantic_segmentor_init() -> None:
14+
"""Tests SemanticSegmentor initialization."""
15+
segmentor = SemanticSegmentor(model="fcn-tissue_mask", device=device)
16+
17+
assert isinstance(segmentor, SemanticSegmentor)
18+
assert isinstance(segmentor.model, torch.nn.Module)

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ def run(
517517
518518
Examples:
519519
>>> wsis = ['wsi1.svs', 'wsi2.svs']
520+
>>> image_patches = [np.ndarray, np.ndarray]
520521
>>> class PatchPredictor(EngineABC):
521522
>>> # Define all Abstract methods.
522523
>>> ...

tiatoolbox/models/engine/semantic_segmentor_new.py

Lines changed: 285 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,85 @@
44

55
from typing import TYPE_CHECKING
66

7-
from .patch_predictor import PatchPredictor
7+
from typing_extensions import Unpack
8+
9+
from .patch_predictor import PatchPredictor, PredictorRunParams
810

911
if TYPE_CHECKING: # pragma: no cover
12+
import os
1013
from pathlib import Path
1114

15+
import numpy as np
16+
17+
from tiatoolbox.annotation import AnnotationStore
18+
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
1219
from tiatoolbox.models.models_abc import ModelABC
20+
from tiatoolbox.type_hints import Resolution
21+
from tiatoolbox.wsicore import WSIReader
22+
23+
24+
class SemanticSegmentorRunParams(PredictorRunParams):
25+
"""Class describing the input parameters for the :func:`EngineABC.run()` method.
26+
27+
Attributes:
28+
batch_size (int):
29+
Number of image patches to feed to the model in a forward pass.
30+
cache_mode (bool):
31+
Whether to run the Engine in cache_mode. For large datasets,
32+
we recommend setting this to True to avoid out-of-memory errors.
33+
For smaller datasets, the cache_mode is set to False as
34+
the results can be saved in memory.
35+
cache_size (int):
36+
Specifies how many image patches to process in a batch when
37+
cache_mode is set to True. If cache_size is less than the batch_size,
38+
batch_size is set to cache_size.
39+
class_dict (dict):
40+
Optional dictionary mapping classification outputs to class names.
41+
device (str):
42+
Select the device to run the model. Please see
43+
https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
44+
for more details on input parameters for device.
45+
ioconfig (ModelIOConfigABC):
46+
Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine.
47+
return_labels (bool):
48+
Whether to return the labels with the predictions.
49+
num_loader_workers (int):
50+
Number of workers used in :class:`torch.utils.data.DataLoader`.
51+
num_post_proc_workers (int):
52+
Number of workers to postprocess the results of the model.
53+
output_file (str):
54+
Output file name to save "zarr" or "db". If None, path to output is
55+
returned by the engine.
56+
patch_input_shape (tuple):
57+
Shape of patches input to the model as tuple of height and width (HW).
58+
Patches are requested at read resolution, not with respect to level 0,
59+
and must be positive.
60+
resolution (Resolution):
61+
Resolution used for reading the image. Please see
62+
:class:`WSIReader` for details.
63+
return_probabilities (bool):
64+
Whether to return per-class probabilities.
65+
scale_factor (tuple[float, float]):
66+
The scale factor to use when loading the
67+
annotations. All coordinates will be multiplied by this factor to allow
68+
conversion of annotations saved at non-baseline resolution to baseline.
69+
Should be model_mpp/slide_mpp.
70+
stride_shape (tuple):
71+
Stride used during WSI processing. Stride is
72+
at requested read resolution, not with respect to
73+
level 0, and must be positive. If not provided,
74+
`stride_shape=patch_input_shape`.
75+
units (Units):
76+
Units of resolution used for reading the image. Choose
77+
from either `level`, `power` or `mpp`. Please see
78+
:class:`WSIReader` for details.
79+
verbose (bool):
80+
Whether to output logging information.
81+
82+
"""
83+
84+
patch_output_shape: tuple
85+
output_resolution: Resolution
1386

1487

1588
class SemanticSegmentor(PatchPredictor):
@@ -52,44 +125,128 @@ class SemanticSegmentor(PatchPredictor):
52125
Use externally defined PyTorch model for prediction with
53126
weights already loaded. Default is `None`. If provided,
54127
`pretrained_model` argument is ignored.
55-
pretrained_model (str):
56-
Name of the existing models support by tiatoolbox for
57-
processing the data. For a full list of pretrained models,
128+
batch_size (int):
129+
Number of images fed into the model each time.
130+
num_loader_workers (int):
131+
Number of workers to load the data using :class:`torch.utils.data.Dataset`.
132+
Please note that they will also perform preprocessing. Default value is 0.
133+
num_post_proc_workers (int):
134+
Number of workers to postprocess the results of the model.
135+
Default value is 0.
136+
weights (str or Path):
137+
Path to the weight of the corresponding `model`.
138+
139+
>>> engine = SemanticSegmentor(
140+
... model="pretrained-model",
141+
... weights="/path/to/pretrained-local-weights.pth"
142+
... )
143+
144+
verbose (bool):
145+
Whether to output logging information.
146+
device (str):
147+
Select the device to run the model. Please see
148+
https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
149+
for more details on input parameters for device. Default is "cpu".
150+
verbose (bool):
151+
Whether to output logging information. Default value is False.
152+
153+
Attributes:
154+
images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
155+
A list of image patches in NHWC format as a numpy array
156+
or a list of str/paths to WSIs.
157+
masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
158+
A list of tissue masks or binary masks corresponding to processing area of
159+
input images. These can be a list of numpy arrays or paths to
160+
the saved image masks. These are only utilized when patch_mode is False.
161+
Patches are only generated within a masked area.
162+
If not provided, then a tissue mask will be automatically
163+
generated for whole slide images.
164+
patch_mode (bool):
165+
Whether to treat input images as a set of image patches. TIAToolbox defines
166+
an image as a patch if HWC of the input image matches with the HWC expected
167+
by the model. If HWC of the input image does not match with the HWC expected
168+
by the model, then the patch_mode must be set to False which will allow the
169+
engine to extract patches from the input image.
170+
In this case, when the patch_mode is False the input images are treated
171+
as WSIs. Default value is True.
172+
model (str | ModelABC):
173+
A PyTorch model or a name of an existing model from the TIAToolbox model zoo
174+
for processing the data. For a full list of pretrained models,
58175
refer to the `docs
59-
<https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
176+
<https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
60177
By default, the corresponding pretrained weights will also
61178
be downloaded. However, you can override with your own set
62-
of weights via the `pretrained_weights` argument. Argument
179+
of weights via the `weights` argument. Argument
63180
is case-insensitive.
64-
pretrained_weights (str):
65-
Path to the weight of the corresponding `pretrained_model`.
181+
ioconfig (IOSegmentorConfig):
182+
Input IO configuration of type :class:`IOSegmentorConfig` to run the Engine.
183+
_ioconfig (IOSegmentorConfig):
184+
Runtime ioconfig.
185+
return_labels (bool):
186+
Whether to return the labels with the predictions.
187+
resolution (Resolution):
188+
Resolution used for reading the image. Please see
189+
:obj:`WSIReader` for details.
190+
units (Units):
191+
Units of resolution used for reading the image. Choose
192+
from either `level`, `power` or `mpp`. Please see
193+
:obj:`WSIReader` for details.
194+
patch_input_shape (tuple):
195+
Shape of patches input to the model as tuple of HW. Patches are at
196+
requested read resolution, not with respect to level 0,
197+
and must be positive.
198+
stride_shape (tuple):
199+
Stride used during WSI processing. Stride is
200+
at requested read resolution, not with respect to
201+
level 0, and must be positive. If not provided,
202+
`stride_shape=patch_input_shape`.
66203
batch_size (int):
67204
Number of images fed into the model each time.
205+
cache_mode (bool):
206+
Whether to run the Engine in cache_mode. For large datasets,
207+
we recommend setting this to True to avoid out-of-memory errors.
208+
For smaller datasets, the cache_mode is set to False as
209+
the results can be saved in memory. cache_mode is always True when
210+
processing WSIs i.e., when `patch_mode` is False. Default value is False.
211+
cache_size (int):
212+
Specifies how many image patches to process in a batch when
213+
cache_mode is set to True. If cache_size is less than the batch_size,
214+
batch_size is set to cache_size. Default value is 10,000.
215+
labels (list | None):
216+
List of labels. Only a single label per image is supported.
217+
device (str):
218+
:class:`torch.device` to run the model.
219+
Select the device to run the model. Please see
220+
https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
221+
for more details on input parameters for device. Default value is "cpu".
68222
num_loader_workers (int):
69-
Number of workers to load the data. Take note that they will
70-
also perform preprocessing.
71-
num_postproc_workers (int):
72-
This value is there to maintain input compatibility with
73-
`tiatoolbox.models.classification` and is not used.
223+
Number of workers used in :class:`torch.utils.data.DataLoader`.
224+
num_post_proc_workers (int):
225+
Number of workers to postprocess the results of the model.
226+
return_labels (bool):
227+
Whether to return the output labels. Default value is False.
228+
resolution (Resolution):
229+
Resolution used for reading the image. Please see
230+
:class:`WSIReader` for details.
231+
When `patch_mode` is True, the input image patches are expected to be at
232+
the correct resolution and units. When `patch_mode` is False, the patches
233+
are extracted at the requested resolution and units. Default value is 1.0.
234+
units (Units):
235+
Units of resolution used for reading the image. Choose
236+
from either `baseline`, `level`, `power` or `mpp`. Please see
237+
:class:`WSIReader` for details.
238+
When `patch_mode` is True, the input image patches are expected to be at
239+
the correct resolution and units. When `patch_mode` is False, the patches
240+
are extracted at the requested resolution and units.
241+
Default value is `baseline`.
74242
verbose (bool):
75-
Whether to output logging information.
76-
dataset_class (obj):
77-
Dataset class to be used instead of default.
78-
auto_generate_mask (bool):
79-
To automatically generate tile/WSI tissue mask if is not
80-
provided.
81-
82-
Attributes:
83-
process_prediction_per_batch (bool):
84-
A flag to denote whether post-processing for inference
85-
output is applied after each batch or after finishing an entire
86-
tile or WSI.
243+
Whether to output logging information. Default value is False.
87244
88245
Examples:
89246
>>> # Sample output of a network
90247
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
91-
>>> predictor = SemanticSegmentor(model='fcn-tissue_mask')
92-
>>> output = predictor.predict(wsis, mode='wsi')
248+
>>> segmentor = SemanticSegmentor(model='fcn-tissue_mask')
249+
>>> output = segmentor.run(wsis, patch_mode=False)
93250
>>> list(output.keys())
94251
[('A/wsi.svs', 'output/0.raw') , ('B/wsi.svs', 'output/1.raw')]
95252
>>> # if a network has 2 output heads, each head output of 'A/wsi.svs'
@@ -118,3 +275,104 @@ def __init__(
118275
device=device,
119276
verbose=verbose,
120277
)
278+
279+
def run(
280+
self: SemanticSegmentor,
281+
images: list[os | Path | WSIReader] | np.ndarray,
282+
masks: list[os | Path] | np.ndarray | None = None,
283+
labels: list | None = None,
284+
ioconfig: IOSegmentorConfig | None = None,
285+
*,
286+
patch_mode: bool = True,
287+
save_dir: os | Path | None = None, # None will not save output
288+
overwrite: bool = False,
289+
output_type: str = "dict",
290+
**kwargs: Unpack[SemanticSegmentorRunParams],
291+
) -> AnnotationStore | Path | str | dict:
292+
"""Run the engine on input images.
293+
294+
Args:
295+
images (list, ndarray):
296+
List of inputs to process. When using `patch` mode, the
297+
input must be either a list of images, a list of image
298+
file paths or a numpy array of an image list.
299+
masks (list | None):
300+
List of masks. Only utilised when patch_mode is False.
301+
Patches are only generated within a masked area.
302+
If not provided, then a tissue mask will be automatically
303+
generated for whole slide images.
304+
labels (list | None):
305+
List of labels. Only a single label per image is supported.
306+
patch_mode (bool):
307+
Whether to treat input image as a patch or WSI.
308+
default = True.
309+
ioconfig (IOSegmentorConfig):
310+
IO configuration.
311+
save_dir (str or pathlib.Path):
312+
Output directory to save the results.
313+
If save_dir is not provided when patch_mode is False,
314+
then for a single image the output is created in the current directory.
315+
If there are multiple WSIs as input then the user must provide
316+
path to save directory otherwise an OSError will be raised.
317+
overwrite (bool):
318+
Whether to overwrite the results. Default = False.
319+
output_type (str):
320+
The format of the output type. "output_type" can be
321+
"dict", "zarr" or "AnnotationStore". Default value is "dict".
322+
When saving in the zarr format the output is saved using the
323+
`python zarr library <https://zarr.readthedocs.io/en/stable/>`__
324+
as a zarr group. If the required output type is an "AnnotationStore"
325+
then the output will be intermediately saved as zarr but converted
326+
to :class:`AnnotationStore` and saved as a `.db` file
327+
at the end of the loop.
328+
**kwargs (SemanticSegmentorRunParams):
329+
Keyword Args to update :class:`EngineABC` attributes during runtime.
330+
331+
Returns:
332+
(:class:`numpy.ndarray`, dict):
333+
Model predictions of the input dataset. If multiple
334+
whole slide images are provided as input,
335+
or `save_dir` is provided, then results are saved to
336+
`save_dir` and a dictionary indicating save location for
337+
each input is returned.
338+
339+
The dict has the following format:
340+
341+
- img_path: path of the input image.
342+
- raw: path to save location for raw prediction,
343+
saved in .json.
344+
345+
Examples:
346+
>>> wsis = ['wsi1.svs', 'wsi2.svs']
347+
>>> image_patches = [np.ndarray, np.ndarray]
348+
>>> class SemanticSegmentor(PatchPredictor):
349+
>>> # Define all Abstract methods.
350+
>>> ...
351+
>>> segmentor = SemanticSegmentor(model="fcn-tissue_mask")
352+
>>> output = segmentor.run(image_patches, patch_mode=True)
353+
>>> output
354+
... "/path/to/Output.db"
355+
>>> output = segmentor.run(
356+
>>> image_patches,
357+
>>> patch_mode=True,
358+
>>> output_type="zarr")
359+
>>> output
360+
... "/path/to/Output.zarr"
361+
>>> output = segmentor.run(wsis, patch_mode=False)
362+
>>> output.keys()
363+
... ['wsi1.svs', 'wsi2.svs']
364+
>>> output['wsi1.svs']
365+
... {'/path/to/wsi1.db'}
366+
367+
"""
368+
return super().run(
369+
images=images,
370+
masks=masks,
371+
labels=labels,
372+
ioconfig=ioconfig,
373+
patch_mode=patch_mode,
374+
save_dir=save_dir,
375+
overwrite=overwrite,
376+
output_type=output_type,
377+
**kwargs,
378+
)

0 commit comments

Comments
 (0)