99
1010import cv2
1111import numpy as np
12- import segmentation_models_pytorch as smp
1312import torch
1413
14+ from tiatoolbox .models .architecture .unetplusplus import UNetPlusPlusModel
1515from tiatoolbox .models .models_abc import ModelABC
1616
1717
18- class TissueDetectionModel (ModelABC ):
19- """GrandQC Tissue Detection Model.
18+ class GrandQCModel (ModelABC ):
19+ """GrandQC Tissue Detection Model [1] .
2020
2121 Example:
2222 >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
@@ -32,28 +32,39 @@ class TissueDetectionModel(ModelABC):
3232 ... output_type="annotationstore",
3333 ... )
3434
35+ References:
36+ [1] Weng Z. et al. "GrandQC: a comprehensive solution to quality control problem
37+ in digital pathology".
38+ Nature Communications 2024
39+
3540 """
3641
3742 def __init__ (
38- self : TissueDetectionModel , num_input_channels : int , num_output_channels : int
43+ self : GrandQCModel , num_input_channels : int , num_output_channels : int
3944 ) -> None :
4045 """Initialize TissueDetectionModel."""
4146 super ().__init__ ()
4247 self .num_input_channels = num_input_channels
4348 self .num_output_channels = num_output_channels
4449 self ._postproc = self .postproc
4550 self ._preproc = self .preproc
46- self .tissue_detection_model = smp .UnetPlusPlus (
47- encoder_name = "timm-efficientnet-b0" ,
48- encoder_weights = None ,
49- in_channels = self .num_input_channels ,
51+ self .tissue_detection_model = UNetPlusPlusModel (
5052 classes = self .num_output_channels ,
51- activation = None ,
5253 )
5354
5455 @staticmethod
5556 def preproc (image : np .ndarray ) -> np .ndarray :
56- """Apply jpg compression then ImageNet normalise."""
57+ """Apply JPEG compression and ImageNet normalization to the input image.
58+
59+ Args:
60+ image (np.ndarray):
61+ Input image as a NumPy array (H, W, C) in uint8 format.
62+
63+ Returns:
64+ np.ndarray:
65+ The preprocessed image.
66+
67+ """
5768 encode_param = [int (cv2 .IMWRITE_JPEG_QUALITY ), 80 ]
5869 _ , compressed_image = cv2 .imencode (".jpg" , image , encode_param )
5970 compressed_image = np .array (cv2 .imdecode (compressed_image , 1 ))
@@ -69,11 +80,19 @@ def postproc(image: np.ndarray) -> np.ndarray:
6980 This simply applies argmin to obtain tissue class.
7081 (Tissue = 0, Background = 1)
7182
83+ Args:
84+ image (np.ndarray):
85+ Input probability map as a NumPy array (H, W, C).
86+
87+ Returns:
88+ np.ndarray:
89+ Tissue mask
90+
7291 """
7392 return image .argmin (axis = - 1 )
7493
7594 def forward (
76- self : TissueDetectionModel ,
95+ self : GrandQCModel ,
7796 imgs : torch .Tensor ,
7897 * args : tuple [Any , ...], # skipcq: PYL-W0613 # noqa: ARG002
7998 ** kwargs : dict , # skipcq: PYL-W0613 # noqa: ARG002
@@ -120,9 +139,9 @@ def infer_batch(
120139 return probs .cpu ().numpy ()
121140
122141 def load_state_dict (
123- self : TissueDetectionModel ,
142+ self : GrandQCModel ,
124143 state_dict : Mapping [str , Any ],
125144 ** kwargs : bool ,
126145 ) -> torch .nn .modules .module ._IncompatibleKeys :
127- """Load state dict for the TissueDetectionModel ."""
146+ """Load state dict for the GrandQCModel ."""
128147 return self .tissue_detection_model .load_state_dict (state_dict , ** kwargs )
0 commit comments