Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ requests>=2.28.1
scikit-image>=0.20
scikit-learn>=1.2.0
scipy>=1.8
segmentation-models-pytorch>=0.5.0
shapely>=2.0.0
SimpleITK>=2.2.1
sphinx>=5.3.0
Expand Down
70 changes: 70 additions & 0 deletions tests/models/test_arch_grandqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Unit test package for GrandQC Tissue Model."""

import numpy as np
import torch

from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import TissueDetectionModel
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import VirtualWSIReader

ON_GPU = False


def test_functional_grandqc() -> None:
"""Test for GrandQC model."""
# test fetch pretrained weights
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10")
assert pretrained_weights is not None

# test creation
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
assert model is not None

# load pretrained weights
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)

# test get pretrained model
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
assert isinstance(model, TissueDetectionModel)
assert isinstance(ioconfig, IOSegmentorConfig)
assert model.num_input_channels == 3
assert model.num_output_channels == 2

# test inference
generator = np.random.default_rng(1337)
test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8)
reader = VirtualWSIReader.open(test_image)
read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"}
batch = np.array(
[
reader.read_bounds((0, 0, 512, 512), **read_kwargs),
reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
],
)
batch = torch.from_numpy(batch)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
assert output.shape == (2, 512, 512, 2)


def test_grandqc_preproc_postproc() -> None:
"""Test GrandQC preproc and postproc functions."""
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)

generator = np.random.default_rng(1337)
# test preproc
dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
preproc_image = model.preproc(dummy_image)
assert preproc_image.shape == dummy_image.shape
assert preproc_image.dtype == np.float64

# test postproc
dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32)
postproc_image = model.postproc(dummy_output)
assert postproc_image.shape == (512, 512)
assert postproc_image.dtype == np.int64
31 changes: 25 additions & 6 deletions tiatoolbox/data/pretrained_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ mapde-crchisto:
threshold_abs: 250
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -837,7 +837,7 @@ mapde-conic:
threshold_abs: 205
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -860,7 +860,7 @@ sccnn-crchisto:
threshold_abs: 0.20
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -883,7 +883,7 @@ sccnn-conic:
threshold_abs: 0.05
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -903,7 +903,7 @@ nuclick_original-pannuke:
num_input_channels: 5
num_output_channels: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
Expand All @@ -925,7 +925,7 @@ nuclick_light-pannuke:
decoder_block: [3,3]
skip_type: "add"
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
Expand All @@ -934,3 +934,22 @@ nuclick_light-pannuke:
patch_input_shape: [128, 128]
patch_output_shape: [128, 128]
save_resolution: {'units': 'baseline', 'resolution': 1.0}

grandqc_tissue_detection_mpp10:
hf_repo_id: TIACentre/GrandQC_Tissue_Detection
architecture:
class: grandqc.TissueDetectionModel
kwargs:
num_input_channels: 3
num_output_channels: 2
ioconfig:
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
output_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
patch_input_shape: [512, 512]
patch_output_shape: [512, 512]
stride_shape: [256, 256]
save_resolution: {'units': 'mpp', 'resolution': 10.0}
128 changes: 128 additions & 0 deletions tiatoolbox/models/architecture/grandqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Define GrandQC Tissue Detection Model architecture."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING: # pragma: no cover
from collections.abc import Mapping

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch

from tiatoolbox.models.models_abc import ModelABC


class TissueDetectionModel(ModelABC):
"""GrandQC Tissue Detection Model.

Example:
>>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
>>> semantic_segmentor = SemanticSegmentor(
... model="grandqc_tissue_detection_mpp10",
... )
>>> results = semantic_segmentor.run(
... ["/example_wsi.svs"],
... masks=None,
... auto_get_mask=False,
... patch_mode=False,
... save_dir=Path("/tissue_mask/"),
... output_type="annotationstore",
... )

"""

def __init__(
self: TissueDetectionModel, num_input_channels: int, num_output_channels: int
) -> None:
"""Initialize TissueDetectionModel."""
super().__init__()
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self._postproc = self.postproc
self._preproc = self.preproc
self.tissue_detection_model = smp.UnetPlusPlus(
encoder_name="timm-efficientnet-b0",
encoder_weights=None,
in_channels=self.num_input_channels,
classes=self.num_output_channels,
activation=None,
)

@staticmethod
def preproc(image: np.ndarray) -> np.ndarray:
"""Apply jpg compression then ImageNet normalise."""
Copy link

Copilot AI Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'normalise' to 'normalize' for consistency with American English spelling used elsewhere in the codebase.

Suggested change
"""Apply jpg compression then ImageNet normalise."""
"""Apply jpg compression then ImageNet normalize."""

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The preproc docstring lacks parameter and return type documentation. Consider adding a complete docstring with Args and Returns sections following the pattern used in other models like HoVerNet and MicroNet.

Suggested change
"""Apply jpg compression then ImageNet normalise."""
"""Apply JPEG compression and ImageNet normalization to the input image.
Args:
image (np.ndarray):
Input image as a NumPy array (H, W, C) in uint8 format.
Returns:
np.ndarray:
The preprocessed image as a float32 NumPy array, normalized using ImageNet mean and std.
"""

Copilot uses AI. Check for mistakes.
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
_, compressed_image = cv2.imencode(".jpg", image, encode_param)
compressed_image = np.array(cv2.imdecode(compressed_image, 1))

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
return (compressed_image / 255.0 - mean) / std

@staticmethod
def postproc(image: np.ndarray) -> np.ndarray:
"""Define post-processing for this model.

This simply applies argmin to obtain tissue class.
(Tissue = 0, Background = 1)

"""
return image.argmin(axis=-1)

def forward(
self: TissueDetectionModel,
imgs: torch.Tensor,
*args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002
**kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002
) -> torch.Tensor:
"""Forward function for model."""
return self.tissue_detection_model(imgs)

@staticmethod
def infer_batch(
model: torch.nn.Module,
batch_data: torch.Tensor,
*,
device: str,
) -> np.ndarray:
"""Run inference on an input batch.

This contains logic for forward operation as well as i/o

Args:
model (nn.Module):
PyTorch defined model.
batch_data (:class:`torch.Tensor`):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Transfers model to the specified device. Default is "cpu".

Returns:
np.ndarray:
The inference results as a numpy array.

"""
model.eval()

imgs = batch_data
imgs = imgs.to(device).type(torch.float32)
imgs = imgs.permute(0, 3, 1, 2) # to NCHW

with torch.inference_mode():
logits = model(imgs)
probs = torch.nn.functional.softmax(logits, 1)
probs = probs.permute(0, 2, 3, 1) # to NHWC

return probs.cpu().numpy()

def load_state_dict(
self: TissueDetectionModel,
state_dict: Mapping[str, Any],
**kwargs: bool,
) -> torch.nn.modules.module._IncompatibleKeys:
"""Load state dict for the TissueDetectionModel."""
return self.tissue_detection_model.load_state_dict(state_dict, **kwargs)
Loading