Skip to content

Commit 6b8eb90

Browse files
committed
address comments
1 parent 5f0202f commit 6b8eb90

File tree

3 files changed

+64
-22
lines changed

3 files changed

+64
-22
lines changed

tests/models/test_arch_grandqc.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Unit test package for GrandQC Tissue Model."""
22

3+
from collections.abc import Callable
4+
from pathlib import Path
5+
36
import numpy as np
47
import torch
58
from torch import nn
69

10+
from tiatoolbox.annotation.storage import SQLiteStore
711
from tiatoolbox.models.architecture import (
812
fetch_pretrained_weights,
913
get_pretrained_model,
@@ -15,16 +19,17 @@
1519
UnetPlusPlusDecoder,
1620
)
1721
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
18-
from tiatoolbox.utils.misc import select_device
22+
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
23+
from tiatoolbox.utils import env_detection as toolbox_env
1924
from tiatoolbox.wsicore.wsireader import VirtualWSIReader
2025

21-
ON_GPU = False
26+
device = "cuda" if toolbox_env.has_gpu() else "cpu"
2227

2328

2429
def test_functional_grandqc() -> None:
2530
"""Test for GrandQC model."""
2631
# test fetch pretrained weights
27-
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10")
32+
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection")
2833
assert pretrained_weights is not None
2934

3035
# test creation
@@ -36,7 +41,7 @@ def test_functional_grandqc() -> None:
3641
model.load_state_dict(pretrained)
3742

3843
# test get pretrained model
39-
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
44+
model, ioconfig = get_pretrained_model("grandqc_tissue_detection")
4045
assert isinstance(model, GrandQCModel)
4146
assert isinstance(ioconfig, IOSegmentorConfig)
4247
assert model.num_output_channels == 2
@@ -54,7 +59,7 @@ def test_functional_grandqc() -> None:
5459
],
5560
)
5661
batch = torch.from_numpy(batch)
57-
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
62+
output = model.infer_batch(model, batch, device=device)
5863
assert output.shape == (2, 512, 512, 2)
5964

6065

@@ -76,6 +81,39 @@ def test_grandqc_preproc_postproc() -> None:
7681
assert postproc_image.dtype == np.int64
7782

7883

84+
def test_grandqc_with_semantic_segmentor(
85+
remote_sample: Callable, track_tmp_path: Path
86+
) -> None:
87+
"""Test GrandQC tissue mask generation."""
88+
segmentor = SemanticSegmentor(model="grandqc_tissue_detection")
89+
90+
sample_image = remote_sample("svs-1-small")
91+
inputs = [str(sample_image)]
92+
93+
output = segmentor.run(
94+
images=inputs,
95+
device=device,
96+
patch_mode=False,
97+
output_type="annotationstore",
98+
save_dir=track_tmp_path / "grandqc_test_outputs",
99+
overwrite=True,
100+
)
101+
102+
assert len(output) == 1
103+
assert Path(output[sample_image]).exists()
104+
105+
store = SQLiteStore.open(output[sample_image])
106+
assert len(store) == 3
107+
108+
tissue_area_px = 0.0
109+
for annotation in store.values():
110+
assert annotation.properties["type"] == "mask"
111+
tissue_area_px += annotation.geometry.area
112+
assert 3003401 < tissue_area_px < 3003402
113+
114+
store.close()
115+
116+
79117
def test_segmentation_head_behaviour() -> None:
80118
"""Verify SegmentationHead defaults and upsampling."""
81119
head = SegmentationHead(3, 5, activation=None, upsampling=1)

tiatoolbox/data/pretrained_model.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ nuclick_light-pannuke:
935935
patch_output_shape: [128, 128]
936936
save_resolution: {'units': 'baseline', 'resolution': 1.0}
937937

938-
grandqc_tissue_detection_mpp10:
938+
grandqc_tissue_detection:
939939
hf_repo_id: TIACentre/GrandQC_Tissue_Detection
940940
architecture:
941941
class: grandqc.GrandQCModel

tiatoolbox/models/architecture/grandqc.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""GrandQC Tissue Detection Model Architecture.
22
33
This module defines the GrandQC model for tissue detection in digital pathology.
4-
It implements a UNet++ architecture with an EfficientNet encoder and a segmentation
4+
It implements a UNet++ architecture with an EfficientNetB0 encoder and a segmentation
55
head for high-resolution tissue segmentation. The model is designed to identify
66
tissue regions and background areas for quality control in whole slide images (WSIs).
77
@@ -205,7 +205,7 @@ class DecoderBlock(nn.Module):
205205
206206
This block performs upsampling and feature fusion using skip connections
207207
from the encoder. It consists of two convolutional layers with ReLU activation
208-
and optional attention mechanisms.
208+
and optional attention mechanisms (not implemented).
209209
210210
Attributes:
211211
conv1 (Conv2dReLU):
@@ -222,9 +222,9 @@ class DecoderBlock(nn.Module):
222222
223223
Example:
224224
>>> block = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64)
225-
>>> x = torch.randn(1, 128, 64, 64)
225+
>>> input_tensor = torch.randn(1, 128, 64, 64)
226226
>>> skip = torch.randn(1, 64, 128, 128)
227-
>>> output = block(x, skip)
227+
>>> output = block(input_tensor, skip)
228228
>>> output.shape
229229
... torch.Size([1, 64, 128, 128])
230230
@@ -268,7 +268,7 @@ def __init__(
268268

269269
def forward(
270270
self: DecoderBlock,
271-
x: torch.Tensor,
271+
input_tensor: torch.Tensor,
272272
skip: torch.Tensor | None = None,
273273
) -> torch.Tensor:
274274
"""Forward pass through the decoder block.
@@ -277,29 +277,33 @@ def forward(
277277
(if provided), and applies two convolutional layers with attention.
278278
279279
Args:
280-
x (torch.Tensor):
281-
Input tensor from the previous decoder layer.
280+
input_tensor (torch.Tensor):
281+
(B, C_in, H, W). Input tensor from the previous decoder layer.
282282
skip (torch.Tensor | None):
283+
(B, C_skip, H*2, W*2).
283284
Skip connection tensor from the encoder. Defaults to None.
284285
285286
Returns:
286287
torch.Tensor:
288+
(B, C_out, H*2, W*2).
287289
Output tensor after decoding and feature refinement.
288290
289291
"""
290-
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
292+
input_tensor = torch.nn.functional.interpolate(
293+
input_tensor, scale_factor=2.0, mode="nearest"
294+
)
291295
if skip is not None:
292-
x = torch.cat([x, skip], dim=1)
293-
x = self.attention1(x)
294-
x = self.conv1(x)
295-
x = self.conv2(x)
296-
return self.attention2(x)
296+
input_tensor = torch.cat([input_tensor, skip], dim=1)
297+
input_tensor = self.attention1(input_tensor)
298+
input_tensor = self.conv1(input_tensor)
299+
input_tensor = self.conv2(input_tensor)
300+
return self.attention2(input_tensor)
297301

298302

299303
class CenterBlock(nn.Sequential):
300304
"""Center block for UNet++ architecture.
301305
302-
This block is placed at the bottleneck of the UNet++ architecture.
306+
This block can be placed at the bottleneck of the UNet++ architecture.
303307
It consists of two convolutional layers with ReLU activation, used
304308
to process the deepest feature maps before decoding begins.
305309
@@ -311,8 +315,8 @@ class CenterBlock(nn.Sequential):
311315
312316
Example:
313317
>>> center = CenterBlock(in_channels=256, out_channels=512)
314-
>>> x = torch.randn(1, 256, 32, 32)
315-
>>> output = center(x)
318+
>>> input_tensor = torch.randn(1, 256, 32, 32)
319+
>>> output = center(input_tensor)
316320
>>> output.shape
317321
... torch.Size([1, 512, 32, 32])
318322

0 commit comments

Comments
 (0)