Skip to content

Commit 8a7295d

Browse files
committed
fix tests
1 parent 899d6cb commit 8a7295d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

tests/models/test_arch_grandqc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_functional_grandqc(remote_sample: Callable) -> None:
2525
assert pretrained_weights is not None
2626

2727
# test creation
28-
model = TissueDetectionModel()
28+
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
2929
assert model is not None
3030

3131
# load pretrained weights
@@ -36,6 +36,8 @@ def test_functional_grandqc(remote_sample: Callable) -> None:
3636
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
3737
assert isinstance(model, TissueDetectionModel)
3838
assert isinstance(ioconfig, IOSegmentorConfig)
39+
assert model.num_input_channels == 3
40+
assert model.num_output_channels == 2
3941

4042
# test inference
4143
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))

tiatoolbox/models/architecture/grandqc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,20 @@ class TissueDetectionModel(ModelABC):
3434
3535
"""
3636

37-
def __init__(self: TissueDetectionModel) -> None:
37+
def __init__(
38+
self: TissueDetectionModel, num_input_channels: int, num_output_channels: int
39+
) -> None:
3840
"""Initialize TissueDetectionModel."""
3941
super().__init__()
42+
self.num_input_channels = num_input_channels
43+
self.num_output_channels = num_output_channels
4044
self._postproc = self.postproc
4145
self._preproc = self.preproc
4246
self.tissue_detection_model = smp.UnetPlusPlus(
4347
encoder_name="timm-efficientnet-b0",
4448
encoder_weights=None,
45-
classes=2,
49+
in_channels=self.num_input_channels,
50+
classes=self.num_output_channels,
4651
activation=None,
4752
)
4853

0 commit comments

Comments
 (0)