Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b18b98f
add grandqc tissue model
Jiaqi-Lv Oct 25, 2025
899d6cb
add example
Jiaqi-Lv Oct 25, 2025
8a7295d
fix tests
Jiaqi-Lv Oct 25, 2025
5c5bfc4
fix error
Jiaqi-Lv Oct 25, 2025
fd692da
update docstring
Jiaqi-Lv Oct 28, 2025
d82cc3d
improve test coverage
Jiaqi-Lv Oct 28, 2025
93a24a1
add unet++ model
Jiaqi-Lv Nov 6, 2025
2d076c0
Merge branch 'dev-define-engines-abc' into dev-add-grandQC
shaneahmed Nov 17, 2025
283b888
Merge branch 'dev-add-grandQC' of https://github.com/TissueImageAnaly…
Jiaqi-Lv Nov 18, 2025
94c43ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
98cef83
remove smp dependency
Jiaqi-Lv Nov 18, 2025
d47fa0a
refactor code
Jiaqi-Lv Nov 18, 2025
d2a66ca
add tests
Jiaqi-Lv Nov 21, 2025
19cca90
address comments
Jiaqi-Lv Nov 21, 2025
1895e38
:memo: Update docstring for grandqc.py and timm_efficientnet.py
shaneahmed Nov 25, 2025
3ade99a
:bug: Fix docstring
shaneahmed Nov 25, 2025
5f0202f
:memo: Remove duplicate docstring for classses.
shaneahmed Nov 25, 2025
6b8eb90
address comments
Jiaqi-Lv Nov 25, 2025
2ce379f
update test
Jiaqi-Lv Nov 25, 2025
9c62b72
:white_check_mark: Add test to improve coverage
shaneahmed Nov 26, 2025
d1ce4a0
improve test coverage
Jiaqi-Lv Nov 27, 2025
717b0ff
improve test coverage
Jiaqi-Lv Nov 27, 2025
3cc5924
address comments
Jiaqi-Lv Nov 28, 2025
1ab6728
:fire: Remove unnecessary checks
shaneahmed Dec 1, 2025
3a27ed6
:bug: Fix incorrect input for bias
shaneahmed Dec 1, 2025
cc4499a
:art: Improve structure of the code.
shaneahmed Dec 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ requests>=2.28.1
scikit-image>=0.20
scikit-learn>=1.2.0
scipy>=1.8
segmentation-models-pytorch>=0.5.0
shapely>=2.0.0
SimpleITK>=2.2.1
sphinx>=5.3.0
Expand Down
54 changes: 54 additions & 0 deletions tests/models/test_arch_grandqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Unit test package for GrandQC Tissue Model."""

from collections.abc import Callable
from pathlib import Path

import numpy as np
import torch

from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import TissueDetectionModel
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = False


def test_functional_grandqc(remote_sample: Callable) -> None:
"""Test for GrandQC model."""
# test fetch pretrained weights
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10")
assert pretrained_weights is not None

# test creation
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
assert model is not None

# load pretrained weights
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)

# test get pretrained model
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
assert isinstance(model, TissueDetectionModel)
assert isinstance(ioconfig, IOSegmentorConfig)
assert model.num_input_channels == 3
assert model.num_output_channels == 2

# test inference
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))
reader = WSIReader.open(mini_wsi_svs)
read_kwargs = {"resolution": 10.0, "units": "mpp", "coord_space": "resolution"}
batch = np.array(
[
reader.read_bounds((0, 0, 512, 512), **read_kwargs),
reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
],
)
batch = torch.from_numpy(batch)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
assert output.shape == (2, 512, 512, 2)
31 changes: 25 additions & 6 deletions tiatoolbox/data/pretrained_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ mapde-crchisto:
threshold_abs: 250
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -837,7 +837,7 @@ mapde-conic:
threshold_abs: 205
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -860,7 +860,7 @@ sccnn-crchisto:
threshold_abs: 0.20
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -883,7 +883,7 @@ sccnn-conic:
threshold_abs: 0.05
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
Expand All @@ -903,7 +903,7 @@ nuclick_original-pannuke:
num_input_channels: 5
num_output_channels: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
Expand All @@ -925,7 +925,7 @@ nuclick_light-pannuke:
decoder_block: [3,3]
skip_type: "add"
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
Expand All @@ -934,3 +934,22 @@ nuclick_light-pannuke:
patch_input_shape: [128, 128]
patch_output_shape: [128, 128]
save_resolution: {'units': 'baseline', 'resolution': 1.0}

grandqc_tissue_detection_mpp10:
hf_repo_id: TIACentre/GrandQC_Tissue_Detection
architecture:
class: grandqc.TissueDetectionModel
kwargs:
num_input_channels: 3
num_output_channels: 2
ioconfig:
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
output_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
patch_input_shape: [512, 512]
patch_output_shape: [512, 512]
stride_shape: [256, 256]
save_resolution: {'units': 'mpp', 'resolution': 10.0}
130 changes: 130 additions & 0 deletions tiatoolbox/models/architecture/grandqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Define GrandQC Tissue Detection Model architecture."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Mapping

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch

from tiatoolbox.models.models_abc import ModelABC


class TissueDetectionModel(ModelABC):
"""GrandQC Tissue Detection Model.

Example:
>>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
>>> semantic_segmentor = SemanticSegmentor(
... model="grandqc_tissue_detection_mpp10",
... )
>>> results = semantic_segmentor.run(
... ["/example_wsi.svs"],
... masks=None,
... auto_get_mask=False,
... patch_mode=False,
... save_dir=Path("/tissue_mask/"),
... output_type="annotationstore",
... )

"""

def __init__(
self: TissueDetectionModel, num_input_channels: int, num_output_channels: int
) -> None:
"""Initialize TissueDetectionModel."""
super().__init__()
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self._postproc = self.postproc
self._preproc = self.preproc
self.tissue_detection_model = smp.UnetPlusPlus(
encoder_name="timm-efficientnet-b0",
encoder_weights=None,
in_channels=self.num_input_channels,
classes=self.num_output_channels,
activation=None,
)

@staticmethod
def preproc(image: np.ndarray) -> np.ndarray:
"""Apply jpg compression then ImageNet normalise."""
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
_, compressed_image = cv2.imencode(".jpg", image, encode_param)
compressed_image = np.array(cv2.imdecode(compressed_image, 1))

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
return (compressed_image / 255.0 - mean) / std

@staticmethod
def postproc(image: np.ndarray) -> np.ndarray:
"""Define post-processing for this model.

This returns the class index with the minimum probability.
In this model, this means selecting tissue class.

"""
return image.argmin(axis=-1)

def forward(
self: TissueDetectionModel,
imgs: torch.Tensor,
*args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002
**kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002
) -> torch.Tensor:
"""Forward function for model."""
return self.tissue_detection_model(imgs)

@staticmethod
def infer_batch(
model: torch.nn.Module,
batch_data: torch.Tensor,
*,
device: str,
) -> np.ndarray:
"""Run inference on an input batch.

This contains logic for forward operation as well as i/o

Args:
model (nn.Module):
PyTorch defined model.
batch_data (:class:`torch.Tensor`):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Transfers model to the specified device. Default is "cpu".

Returns:
np.ndarray:
The inference results as a numpy array.

"""
model.eval()

####
imgs = batch_data

imgs = imgs.to(device).type(torch.float32)
imgs = imgs.permute(0, 3, 1, 2) # to NCHW

with torch.inference_mode():
logits = model(imgs)
probs = torch.nn.functional.softmax(logits, 1)
probs = probs.permute(0, 2, 3, 1) # to NHWC

return probs.cpu().numpy()

def load_state_dict(
self: TissueDetectionModel,
state_dict: Mapping[str, Any],
**kwargs: bool,
) -> torch.nn.modules.module._IncompatibleKeys:
"""Load state dict for the TissueDetectionModel."""
return self.tissue_detection_model.load_state_dict(state_dict, **kwargs)
Loading