File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed
tiatoolbox/models/architecture Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -25,7 +25,7 @@ def test_functional_grandqc(remote_sample: Callable) -> None:
2525 assert pretrained_weights is not None
2626
2727 # test creation
28- model = TissueDetectionModel ()
28+ model = TissueDetectionModel (num_input_channels = 3 , num_output_channels = 2 )
2929 assert model is not None
3030
3131 # load pretrained weights
@@ -36,6 +36,8 @@ def test_functional_grandqc(remote_sample: Callable) -> None:
3636 model , ioconfig = get_pretrained_model ("grandqc_tissue_detection_mpp10" )
3737 assert isinstance (model , TissueDetectionModel )
3838 assert isinstance (ioconfig , IOSegmentorConfig )
39+ assert model .num_input_channels == 3
40+ assert model .num_output_channels == 2
3941
4042 # test inference
4143 mini_wsi_svs = Path (remote_sample ("wsi2_4k_4k_svs" ))
Original file line number Diff line number Diff line change @@ -34,15 +34,20 @@ class TissueDetectionModel(ModelABC):
3434
3535 """
3636
37- def __init__ (self : TissueDetectionModel ) -> None :
37+ def __init__ (
38+ self : TissueDetectionModel , num_input_channels : int , num_output_channels : int
39+ ) -> None :
3840 """Initialize TissueDetectionModel."""
3941 super ().__init__ ()
42+ self .num_input_channels = num_input_channels
43+ self .num_output_channels = num_output_channels
4044 self ._postproc = self .postproc
4145 self ._preproc = self .preproc
4246 self .tissue_detection_model = smp .UnetPlusPlus (
4347 encoder_name = "timm-efficientnet-b0" ,
4448 encoder_weights = None ,
45- classes = 2 ,
49+ in_channels = self .num_input_channels ,
50+ classes = self .num_output_channels ,
4651 activation = None ,
4752 )
4853
You can’t perform that action at this time.
0 commit comments