 import pytest
 import torch
 
-from cellseg_models_pytorch.models import MultiTaskUnet, get_model
+from cellseg_models_pytorch.models import get_model
 
 
 @pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
 @pytest.mark.parametrize("model_type", ["base", "plus"])
-@pytest.mark.parametrize("style_channels", [None, 32])
-@pytest.mark.parametrize("add_stem_skip", [False, True])
-def test_cppnet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
+def test_cppnet_fwdbwd(enc_name, model_type):
     n_rays = 3
     x = torch.rand([1, 3, 64, 64])
     model = get_model(
         name="cppnet",
         type=model_type,
         enc_name=enc_name,
         n_rays=n_rays,
-        ntypes=3,
-        ntissues=3,
-        style_channels=style_channels,
-        add_stem_skip=add_stem_skip,
+        n_type_classes=3,
+        n_sem_classes=3,
         enc_pretrain=False,
     )
 
     y = model(x)
-    y["stardist_refined"].mean().backward()
+    y["stardist-stardist"].mean().backward()
 
-    assert y["type"].shape == x.shape
-    assert y["stardist_refined"].shape == torch.Size([1, n_rays, 64, 64])
+    assert y["type-type"].shape == x.shape
+    assert y["stardist-stardist"].shape == torch.Size([1, n_rays, 64, 64])
 
-    if "sem" in y.keys():
-        assert y["sem"].shape == torch.Size([1, 3, 64, 64])
+    if "sem-sem" in y.keys():
+        assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
 
 
 @pytest.mark.parametrize(
@@ -43,151 +39,119 @@ def test_cppnet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
     ],
 )
 @pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
-@pytest.mark.parametrize("style_channels", [None, 32])
-def test_cellvit_fwdbwd(enc_name, model_type, style_channels):
+def test_cellvit_fwdbwd(enc_name, model_type):
     x = torch.rand([1, 3, 32, 32])
     model = get_model(
         name="cellvit",
         type=model_type,
         enc_name=enc_name,
-        ntypes=3,
-        ntissues=3,
-        style_channels=style_channels,
+        n_type_classes=3,
+        n_sem_classes=3,
         enc_pretrain=False,
+        enc_freeze=True,
     )
-    model.freeze_encoder()
 
     y = model(x)
-    y["hovernet"].mean().backward()
+    y["hovernet-hovernet"].mean().backward()
 
-    assert y["type"].shape == x.shape
+    assert y["type-type"].shape == x.shape
 
-    if "sem" in y.keys():
-        assert y["sem"].shape == torch.Size([1, 3, 32, 32])
+    if "sem-sem" in y.keys():
+        assert y["sem-sem"].shape == torch.Size([1, 3, 32, 32])
 
 
 @pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
 @pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
-@pytest.mark.parametrize("style_channels", [None, 32])
-@pytest.mark.parametrize("add_stem_skip", [False, True])
-def test_hovernet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
+@pytest.mark.parametrize("stem_skip_kws", [None, {"short_skip": "residual"}])
+@pytest.mark.parametrize("style_channels", [None, 256])
+def test_hovernet_fwdbwd(enc_name, model_type, stem_skip_kws, style_channels):
     x = torch.rand([1, 3, 64, 64])
     model = get_model(
         name="hovernet",
         type=model_type,
         enc_name=enc_name,
-        ntypes=3,
-        ntissues=3,
-        style_channels=style_channels,
-        add_stem_skip=add_stem_skip,
+        n_type_classes=3,
+        n_sem_classes=3,
         enc_pretrain=False,
+        style_channels=style_channels,
+        stem_skip_kws=stem_skip_kws,
     )
 
     y = model(x)
-    y["hovernet"].mean().backward()
+    y["hovernet-hovernet"].mean().backward()
 
-    assert y["type"].shape == x.shape
+    assert y["type-type"].shape == x.shape
 
-    if "sem" in y.keys():
-        assert y["sem"].shape == torch.Size([1, 3, 64, 64])
+    if "sem-sem" in y.keys():
+        assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
 
 
 @pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
 @pytest.mark.parametrize("model_type", ["base", "plus"])
-@pytest.mark.parametrize("style_channels", [None, 32])
-@pytest.mark.parametrize("add_stem_skip", [False, True])
-def test_stardist_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
+def test_stardist_fwdbwd(enc_name, model_type):
     n_rays = 3
     x = torch.rand([1, 3, 64, 64])
     model = get_model(
         name="stardist",
         type=model_type,
         n_rays=n_rays,
         enc_name=enc_name,
-        ntypes=3,
-        ntissues=3,
-        style_channels=style_channels,
-        add_stem_skip=add_stem_skip,
+        n_type_classes=3,
+        n_sem_classes=3,
         enc_pretrain=False,
     )
 
     y = model(x)
-    y["stardist"].mean().backward()
+    y["stardist-stardist"].mean().backward()
 
-    assert y["type"].shape == x.shape
-    assert y["stardist"].shape == torch.Size([1, n_rays, 64, 64])
+    assert y["stardist-type"].shape == x.shape
+    assert y["stardist-stardist"].shape == torch.Size([1, n_rays, 64, 64])
 
-    if "sem" in y.keys():
-        assert y["sem"].shape == torch.Size([1, 3, 64, 64])
+    if "sem-sem" in y.keys():
+        assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
 
 
 @pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
 @pytest.mark.parametrize("model_type", ["base", "plus"])
-@pytest.mark.parametrize("add_stem_skip", [False, True])
-def test_cellpose_fwdbwd(enc_name, model_type, add_stem_skip):
+def test_cellpose_fwdbwd(enc_name, model_type):
     x = torch.rand([1, 3, 64, 64])
     model = get_model(
         name="cellpose",
         type=model_type,
         enc_name=enc_name,
-        ntypes=3,
-        ntissues=3,
-        add_stem_skip=add_stem_skip,
+        n_type_classes=3,
+        n_sem_classes=3,
         enc_pretrain=False,
     )
 
     y = model(x)
-    y["cellpose"].mean().backward()
+    y["cellpose-cellpose"].mean().backward()
 
-    assert y["type"].shape == x.shape
-    assert y["cellpose"].shape == torch.Size([1, 2, 64, 64])
+    assert y["cellpose-type"].shape == x.shape
+    assert y["cellpose-cellpose"].shape == torch.Size([1, 2, 64, 64])
 
-    if "sem" in y.keys():
-        assert y["sem"].shape == torch.Size([1, 3, 64, 64])
+    if "sem-sem" in y.keys():
+        assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
 
 
 @pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
 @pytest.mark.parametrize("model_type", ["base", "plus"])
-@pytest.mark.parametrize("add_stem_skip", [False, True])
-def test_cellpose_fwdbwd(enc_name, model_type, add_stem_skip):
+def test_omnipose_fwdbwd(enc_name, model_type):
     x = torch.rand([1, 3, 64, 64])
     model = get_model(
         name="omnipose",
         type=model_type,
         enc_name=enc_name,
-        ntypes=3,
-        ntissues=3,
-        add_stem_skip=add_stem_skip,
+        n_type_classes=3,
+        n_sem_classes=3,
         enc_pretrain=False,
     )
 
     y = model(x)
-    y["omnipose"].mean().backward()
-
-    assert y["type"].shape == x.shape
-    assert y["omnipose"].shape == torch.Size([1, 2, 64, 64])
-
-    if "sem" in y.keys():
-        assert y["sem"].shape == torch.Size([1, 3, 64, 64])
+    y["omnipose-omnipose"].mean().backward()
 
+    assert y["omnipose-type"].shape == x.shape
+    assert y["omnipose-omnipose"].shape == torch.Size([1, 2, 64, 64])
 
-@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
-@pytest.mark.parametrize("add_stem_skip", [False, True])
-def test_multitaskunet_fwdbwd(enc_name, add_stem_skip):
-    x = torch.rand([1, 3, 64, 64])
-    m = MultiTaskUnet(
-        decoders=("sem",),
-        heads={"sem": {"sem": 3}},
-        n_conv_layers={"sem": (1, 1, 1, 1)},
-        n_conv_blocks={"sem": ((2,), (2,), (2,), (2,))},
-        out_channels={"sem": (128, 64, 32, 16)},
-        long_skips={"sem": "unet"},
-        dec_params={"sem": None},
-        add_stem_skip=add_stem_skip,
-        enc_name=enc_name,
-        enc_pretrain=False,
-    )
-    y = m(x)
-    y["sem"].mean().backward()
-
-    assert y["sem"].shape == torch.Size([1, 3, 64, 64])
+    if "sem-sem" in y.keys():
+        assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
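For reference, a minimal usage sketch (not part of the commit) of how one of the models exercised above might be built with the renamed keyword arguments; it only reuses the calls that appear in the updated tests, and the "<decoder>-<head>" output-key naming is assumed from those tests.

import torch

from cellseg_models_pytorch.models import get_model

# Build a hovernet model with 3 cell-type and 3 tissue/semantic classes,
# mirroring the keyword arguments used in the updated tests.
model = get_model(
    name="hovernet",
    type="base",
    enc_name="resnet18",
    n_type_classes=3,
    n_sem_classes=3,
    enc_pretrain=False,
)

x = torch.rand([1, 3, 64, 64])
y = model(x)  # dict of tensors keyed like "hovernet-hovernet", "type-type", "sem-sem"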