diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 045a4ce4e..d99b62e10 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -28,6 +28,7 @@ requests>=2.28.1 scikit-image>=0.20 scikit-learn>=1.2.0 scipy>=1.8 +segmentation-models-pytorch>=0.5.0 shapely>=2.0.0 SimpleITK>=2.2.1 sphinx>=5.3.0 diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py new file mode 100644 index 000000000..1c124fc09 --- /dev/null +++ b/tests/models/test_arch_grandqc.py @@ -0,0 +1,70 @@ +"""Unit test package for GrandQC Tissue Model.""" + +import numpy as np +import torch + +from tiatoolbox.models.architecture import ( + fetch_pretrained_weights, + get_pretrained_model, +) +from tiatoolbox.models.architecture.grandqc import TissueDetectionModel +from tiatoolbox.models.engine.io_config import IOSegmentorConfig +from tiatoolbox.utils.misc import select_device +from tiatoolbox.wsicore.wsireader import VirtualWSIReader + +ON_GPU = False + + +def test_functional_grandqc() -> None: + """Test for GrandQC model.""" + # test fetch pretrained weights + pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10") + assert pretrained_weights is not None + + # test creation + model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) + assert model is not None + + # load pretrained weights + pretrained = torch.load(pretrained_weights, map_location="cpu") + model.load_state_dict(pretrained) + + # test get pretrained model + model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") + assert isinstance(model, TissueDetectionModel) + assert isinstance(ioconfig, IOSegmentorConfig) + assert model.num_input_channels == 3 + assert model.num_output_channels == 2 + + # test inference + generator = np.random.default_rng(1337) + test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8) + reader = VirtualWSIReader.open(test_image) + read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"} + batch = np.array( + [ + reader.read_bounds((0, 0, 512, 512), **read_kwargs), + reader.read_bounds((512, 512, 1024, 1024), **read_kwargs), + ], + ) + batch = torch.from_numpy(batch) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) + assert output.shape == (2, 512, 512, 2) + + +def test_grandqc_preproc_postproc() -> None: + """Test GrandQC preproc and postproc functions.""" + model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) + + generator = np.random.default_rng(1337) + # test preproc + dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8) + preproc_image = model.preproc(dummy_image) + assert preproc_image.shape == dummy_image.shape + assert preproc_image.dtype == np.float64 + + # test postproc + dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32) + postproc_image = model.postproc(dummy_output) + assert postproc_image.shape == (512, 512) + assert postproc_image.dtype == np.int64 diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 880c623fe..434ec936a 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -815,7 +815,7 @@ mapde-crchisto: threshold_abs: 250 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -837,7 +837,7 @@ mapde-conic: threshold_abs: 205 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -860,7 +860,7 @@ sccnn-crchisto: threshold_abs: 0.20 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -883,7 +883,7 @@ sccnn-conic: threshold_abs: 0.05 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -903,7 +903,7 @@ nuclick_original-pannuke: num_input_channels: 5 num_output_channels: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -925,7 +925,7 @@ nuclick_light-pannuke: decoder_block: [3,3] skip_type: "add" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -934,3 +934,22 @@ nuclick_light-pannuke: patch_input_shape: [128, 128] patch_output_shape: [128, 128] save_resolution: {'units': 'baseline', 'resolution': 1.0} + +grandqc_tissue_detection_mpp10: + hf_repo_id: TIACentre/GrandQC_Tissue_Detection + architecture: + class: grandqc.TissueDetectionModel + kwargs: + num_input_channels: 3 + num_output_channels: 2 + ioconfig: + class: io_config.IOSegmentorConfig + kwargs: + input_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + output_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + patch_input_shape: [512, 512] + patch_output_shape: [512, 512] + stride_shape: [256, 256] + save_resolution: {'units': 'mpp', 'resolution': 10.0} diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py new file mode 100644 index 000000000..a2ffdf2db --- /dev/null +++ b/tiatoolbox/models/architecture/grandqc.py @@ -0,0 +1,128 @@ +"""Define GrandQC Tissue Detection Model architecture.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Mapping + +import cv2 +import numpy as np +import segmentation_models_pytorch as smp +import torch + +from tiatoolbox.models.models_abc import ModelABC + + +class TissueDetectionModel(ModelABC): + """GrandQC Tissue Detection Model. + + Example: + >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor + >>> semantic_segmentor = SemanticSegmentor( + ... model="grandqc_tissue_detection_mpp10", + ... ) + >>> results = semantic_segmentor.run( + ... ["/example_wsi.svs"], + ... masks=None, + ... auto_get_mask=False, + ... patch_mode=False, + ... save_dir=Path("/tissue_mask/"), + ... output_type="annotationstore", + ... ) + + """ + + def __init__( + self: TissueDetectionModel, num_input_channels: int, num_output_channels: int + ) -> None: + """Initialize TissueDetectionModel.""" + super().__init__() + self.num_input_channels = num_input_channels + self.num_output_channels = num_output_channels + self._postproc = self.postproc + self._preproc = self.preproc + self.tissue_detection_model = smp.UnetPlusPlus( + encoder_name="timm-efficientnet-b0", + encoder_weights=None, + in_channels=self.num_input_channels, + classes=self.num_output_channels, + activation=None, + ) + + @staticmethod + def preproc(image: np.ndarray) -> np.ndarray: + """Apply jpg compression then ImageNet normalise.""" + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] + _, compressed_image = cv2.imencode(".jpg", image, encode_param) + compressed_image = np.array(cv2.imdecode(compressed_image, 1)) + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + return (compressed_image / 255.0 - mean) / std + + @staticmethod + def postproc(image: np.ndarray) -> np.ndarray: + """Define post-processing for this model. + + This simply applies argmin to obtain tissue class. + (Tissue = 0, Background = 1) + + """ + return image.argmin(axis=-1) + + def forward( + self: TissueDetectionModel, + imgs: torch.Tensor, + *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 + **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 + ) -> torch.Tensor: + """Forward function for model.""" + return self.tissue_detection_model(imgs) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: torch.Tensor, + *, + device: str, + ) -> np.ndarray: + """Run inference on an input batch. + + This contains logic for forward operation as well as i/o + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (:class:`torch.Tensor`): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + np.ndarray: + The inference results as a numpy array. + + """ + model.eval() + + imgs = batch_data + imgs = imgs.to(device).type(torch.float32) + imgs = imgs.permute(0, 3, 1, 2) # to NCHW + + with torch.inference_mode(): + logits = model(imgs) + probs = torch.nn.functional.softmax(logits, 1) + probs = probs.permute(0, 2, 3, 1) # to NHWC + + return probs.cpu().numpy() + + def load_state_dict( + self: TissueDetectionModel, + state_dict: Mapping[str, Any], + **kwargs: bool, + ) -> torch.nn.modules.module._IncompatibleKeys: + """Load state dict for the TissueDetectionModel.""" + return self.tissue_detection_model.load_state_dict(state_dict, **kwargs)