Skip to content

Commit 98cef83

Browse files
committed
remove smp dependency
1 parent 94c43ee commit 98cef83

File tree

7 files changed

+633
-537
lines changed

7 files changed

+633
-537
lines changed

requirements/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ requests>=2.28.1
2828
scikit-image>=0.20
2929
scikit-learn>=1.2.0
3030
scipy>=1.8
31-
segmentation-models-pytorch>=0.5.0
3231
shapely>=2.0.0
3332
SimpleITK>=2.2.1
3433
sphinx>=5.3.0

tests/models/test_arch_grandqc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
fetch_pretrained_weights,
88
get_pretrained_model,
99
)
10-
from tiatoolbox.models.architecture.grandqc import TissueDetectionModel
10+
from tiatoolbox.models.architecture.grandqc import GrandQCModel
1111
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
1212
from tiatoolbox.utils.misc import select_device
1313
from tiatoolbox.wsicore.wsireader import VirtualWSIReader
@@ -22,7 +22,7 @@ def test_functional_grandqc() -> None:
2222
assert pretrained_weights is not None
2323

2424
# test creation
25-
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
25+
model = GrandQCModel(num_input_channels=3, num_output_channels=2)
2626
assert model is not None
2727

2828
# load pretrained weights
@@ -31,7 +31,7 @@ def test_functional_grandqc() -> None:
3131

3232
# test get pretrained model
3333
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
34-
assert isinstance(model, TissueDetectionModel)
34+
assert isinstance(model, GrandQCModel)
3535
assert isinstance(ioconfig, IOSegmentorConfig)
3636
assert model.num_input_channels == 3
3737
assert model.num_output_channels == 2
@@ -54,7 +54,7 @@ def test_functional_grandqc() -> None:
5454

5555
def test_grandqc_preproc_postproc() -> None:
5656
"""Test GrandQC preproc and postproc functions."""
57-
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
57+
model = GrandQCModel(num_input_channels=3, num_output_channels=2)
5858

5959
generator = np.random.default_rng(1337)
6060
# test preproc

tiatoolbox/data/pretrained_model.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ nuclick_light-pannuke:
938938
grandqc_tissue_detection_mpp10:
939939
hf_repo_id: TIACentre/GrandQC_Tissue_Detection
940940
architecture:
941-
class: grandqc.TissueDetectionModel
941+
class: grandqc.GrandQCModel
942942
kwargs:
943943
num_input_channels: 3
944944
num_output_channels: 2

tiatoolbox/models/architecture/grandqc.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
import cv2
1111
import numpy as np
12-
import segmentation_models_pytorch as smp
1312
import torch
1413

14+
from tiatoolbox.models.architecture.unetplusplus import UNetPlusPlusModel
1515
from tiatoolbox.models.models_abc import ModelABC
1616

1717

18-
class TissueDetectionModel(ModelABC):
19-
"""GrandQC Tissue Detection Model.
18+
class GrandQCModel(ModelABC):
19+
"""GrandQC Tissue Detection Model [1].
2020
2121
Example:
2222
>>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
@@ -32,28 +32,39 @@ class TissueDetectionModel(ModelABC):
3232
... output_type="annotationstore",
3333
... )
3434
35+
References:
36+
[1] Weng Z. et al. "GrandQC: a comprehensive solution to quality control problem
37+
in digital pathology".
38+
Nature Communications 2024
39+
3540
"""
3641

3742
def __init__(
38-
self: TissueDetectionModel, num_input_channels: int, num_output_channels: int
43+
self: GrandQCModel, num_input_channels: int, num_output_channels: int
3944
) -> None:
4045
"""Initialize TissueDetectionModel."""
4146
super().__init__()
4247
self.num_input_channels = num_input_channels
4348
self.num_output_channels = num_output_channels
4449
self._postproc = self.postproc
4550
self._preproc = self.preproc
46-
self.tissue_detection_model = smp.UnetPlusPlus(
47-
encoder_name="timm-efficientnet-b0",
48-
encoder_weights=None,
49-
in_channels=self.num_input_channels,
51+
self.tissue_detection_model = UNetPlusPlusModel(
5052
classes=self.num_output_channels,
51-
activation=None,
5253
)
5354

5455
@staticmethod
5556
def preproc(image: np.ndarray) -> np.ndarray:
56-
"""Apply jpg compression then ImageNet normalise."""
57+
"""Apply JPEG compression and ImageNet normalization to the input image.
58+
59+
Args:
60+
image (np.ndarray):
61+
Input image as a NumPy array (H, W, C) in uint8 format.
62+
63+
Returns:
64+
np.ndarray:
65+
The preprocessed image.
66+
67+
"""
5768
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
5869
_, compressed_image = cv2.imencode(".jpg", image, encode_param)
5970
compressed_image = np.array(cv2.imdecode(compressed_image, 1))
@@ -69,11 +80,19 @@ def postproc(image: np.ndarray) -> np.ndarray:
6980
This simply applies argmin to obtain tissue class.
7081
(Tissue = 0, Background = 1)
7182
83+
Args:
84+
image (np.ndarray):
85+
Input probability map as a NumPy array (H, W, C).
86+
87+
Returns:
88+
np.ndarray:
89+
Tissue mask
90+
7291
"""
7392
return image.argmin(axis=-1)
7493

7594
def forward(
76-
self: TissueDetectionModel,
95+
self: GrandQCModel,
7796
imgs: torch.Tensor,
7897
*args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002
7998
**kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002
@@ -120,9 +139,9 @@ def infer_batch(
120139
return probs.cpu().numpy()
121140

122141
def load_state_dict(
123-
self: TissueDetectionModel,
142+
self: GrandQCModel,
124143
state_dict: Mapping[str, Any],
125144
**kwargs: bool,
126145
) -> torch.nn.modules.module._IncompatibleKeys:
127-
"""Load state dict for the TissueDetectionModel."""
146+
"""Load state dict for the GrandQCModel."""
128147
return self.tissue_detection_model.load_state_dict(state_dict, **kwargs)

0 commit comments

Comments
 (0)