Skip to content

Commit e2840c6

Browse files
committed
test: adjust tests
1 parent 511addf commit e2840c6

File tree

22 files changed

+292
-282
lines changed

22 files changed

+292
-282
lines changed

cellseg_models_pytorch/datasets/tests/__init__.py

Whitespace-only changes.

cellseg_models_pytorch/decoders/tests/test_decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from cellseg_models_pytorch.decoders import UnetDecoder
4+
from cellseg_models_pytorch.decoders.unet_decoder import UnetDecoder
55

66

77
@pytest.mark.parametrize(
Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,52 @@
1-
import pytest
2-
3-
from cellseg_models_pytorch.inference import ResizeInferer, SlidingWindowInferer
4-
from cellseg_models_pytorch.models import cellpose_plus
5-
6-
7-
@pytest.mark.parametrize("batch_size", [1, 2])
8-
def test_slidingwin_inference(img_dir, batch_size):
9-
model = cellpose_plus(sem_classes=3, type_classes=3, long_skip="unet")
10-
11-
inferer = SlidingWindowInferer(
12-
model,
13-
img_dir,
14-
out_activations={"sem": "softmax", "type": "softmax", "cellpose": "tanh"},
15-
out_boundary_weights={"sem": False, "type": False, "cellpose": True},
16-
patch_size=(256, 256),
17-
stride=256,
18-
padding=80,
19-
instance_postproc="hovernet",
20-
batch_size=batch_size,
21-
device="cpu",
22-
parallel=False,
23-
)
24-
25-
inferer.infer()
26-
27-
samples = list(inferer.out_masks.keys())
28-
assert inferer.out_masks[samples[0]]["inst"].shape == (512, 512)
29-
30-
31-
@pytest.mark.parametrize("batch_size", [1, 2])
32-
def test_resize_inference(img_dir, batch_size):
33-
model = cellpose_plus(sem_classes=3, type_classes=3, long_skip="unet")
34-
35-
inferer = ResizeInferer(
36-
model,
37-
img_dir,
38-
out_activations={"sem": "softmax", "type": "softmax", "cellpose": "tanh"},
39-
out_boundary_weights={"sem": False, "type": False, "cellpose": True},
40-
resize=(256, 256),
41-
padding=80,
42-
instance_postproc="hovernet",
43-
batch_size=batch_size,
44-
device="cpu",
45-
parallel=False,
46-
)
47-
48-
inferer.infer()
49-
50-
samples = list(inferer.out_masks.keys())
51-
assert inferer.out_masks[samples[0]]["inst"].shape == (512, 512)
1+
# import pytest
2+
3+
# from cellseg_models_pytorch.inference import Inferer, SlidingWindowInferer
4+
# from cellseg_models_pytorch.models import cellpose_plus
5+
# import numpy as np
6+
# import torch
7+
8+
# def test_slidingwin_inferer(img_sample512):
9+
# model = cellpose_plus(n_sem_classes=3, n_type_classes=3, long_skip="unet")
10+
11+
# inferer = SlidingWindowInferer(
12+
# model,
13+
# patch_shape=(256, 256),
14+
# stride=256,
15+
# out_activations={"sem-sem": "softmax", "cellpose-type": "softmax", "cellpose-cellpose": "tanh"},
16+
# out_boundary_weights={"sem-sem": False, "cellpose-type": False, "cellpose-cellpose": True},
17+
# post_proc_method="cellpose",
18+
# num_post_proc_threads=1,
19+
# mixed_precision=True,
20+
# )
21+
22+
# im = img_sample512
23+
# im = np.transpose(im, (2, 0, 1))
24+
# im = np.expand_dims(im, axis=0)
25+
# im = torch.tensor(im)
26+
# probs = inferer.predict(im.float())
27+
# out_masks = inferer.post_process(probs)
28+
29+
# assert out_masks["sem-sem"].shape == (256, 256)
30+
31+
32+
# def test_inferer(img_sample256):
33+
# model = cellpose_plus(n_sem_classes=3, n_type_classes=3, long_skip="unet")
34+
35+
# inferer = Inferer(
36+
# model,
37+
# input_shape=(256, 256),
38+
# out_activations={"sem-sem": "softmax", "cellpose-type": "softmax", "cellpose-cellpose": "tanh"},
39+
# out_boundary_weights={"sem-sem": False, "cellpose-type": False, "cellpose-cellpose": True},
40+
# post_proc_method="cellpose",
41+
# num_post_proc_threads=1,
42+
# mixed_precision=True,
43+
# )
44+
45+
# im = img_sample256
46+
# im = np.transpose(im, (2, 0, 1))
47+
# im = np.expand_dims(im, axis=0)
48+
# im = torch.tensor(im)
49+
# probs = inferer.predict(im.float())
50+
# out_masks = inferer.post_process(probs)
51+
52+
# assert out_masks["sem-sem"].shape == (256, 256)

cellseg_models_pytorch/losses/tests/test_multitask_loss.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import OrderedDict
2-
31
import pytest
42
import torch
53

@@ -20,16 +18,14 @@
2018
"losses",
2119
[
2220
{"inst": JointLoss([TverskyLoss(), IoULoss()])},
23-
{"inst": CELoss(), "type": SSIM()},
21+
{"inst": CELoss(), "types": SSIM()},
2422
],
2523
)
2624
def test_multitask_loss(n_classes, losses):
27-
losses = OrderedDict(losses)
28-
2925
yhats = {}
3026
targets = {}
3127
for i, br in enumerate(losses.keys(), 1):
32-
yhats[f"{br}"], targets[f"{br}"] = _get_dummy_pair(n_classes + i)
28+
yhats[br], targets[br] = _get_dummy_pair(n_classes + i)
3329

3430
mtl = MultiTaskLoss(losses)
3531
mtl(yhats, targets)
Lines changed: 52 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,32 @@
11
import pytest
22
import torch
33

4-
from cellseg_models_pytorch.models import MultiTaskUnet, get_model
4+
from cellseg_models_pytorch.models import get_model
55

66

77
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
88
@pytest.mark.parametrize("model_type", ["base", "plus"])
9-
@pytest.mark.parametrize("style_channels", [None, 32])
10-
@pytest.mark.parametrize("add_stem_skip", [False, True])
11-
def test_cppnet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
9+
def test_cppnet_fwdbwd(enc_name, model_type):
1210
n_rays = 3
1311
x = torch.rand([1, 3, 64, 64])
1412
model = get_model(
1513
name="cppnet",
1614
type=model_type,
1715
enc_name=enc_name,
1816
n_rays=n_rays,
19-
ntypes=3,
20-
ntissues=3,
21-
style_channels=style_channels,
22-
add_stem_skip=add_stem_skip,
17+
n_type_classes=3,
18+
n_sem_classes=3,
2319
enc_pretrain=False,
2420
)
2521

2622
y = model(x)
27-
y["stardist_refined"].mean().backward()
23+
y["stardist-stardist"].mean().backward()
2824

29-
assert y["type"].shape == x.shape
30-
assert y["stardist_refined"].shape == torch.Size([1, n_rays, 64, 64])
25+
assert y["type-type"].shape == x.shape
26+
assert y["stardist-stardist"].shape == torch.Size([1, n_rays, 64, 64])
3127

32-
if "sem" in y.keys():
33-
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
28+
if "sem-sem" in y.keys():
29+
assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
3430

3531

3632
@pytest.mark.parametrize(
@@ -43,151 +39,119 @@ def test_cppnet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
4339
],
4440
)
4541
@pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
46-
@pytest.mark.parametrize("style_channels", [None, 32])
47-
def test_cellvit_fwdbwd(enc_name, model_type, style_channels):
42+
def test_cellvit_fwdbwd(enc_name, model_type):
4843
x = torch.rand([1, 3, 32, 32])
4944
model = get_model(
5045
name="cellvit",
5146
type=model_type,
5247
enc_name=enc_name,
53-
ntypes=3,
54-
ntissues=3,
55-
style_channels=style_channels,
48+
n_type_classes=3,
49+
n_sem_classes=3,
5650
enc_pretrain=False,
51+
enc_freeze=True
5752
)
58-
model.freeze_encoder()
5953

6054
y = model(x)
61-
y["hovernet"].mean().backward()
55+
y["hovernet-hovernet"].mean().backward()
6256

63-
assert y["type"].shape == x.shape
57+
assert y["type-type"].shape == x.shape
6458

65-
if "sem" in y.keys():
66-
assert y["sem"].shape == torch.Size([1, 3, 32, 32])
59+
if "sem-sem" in y.keys():
60+
assert y["sem-sem"].shape == torch.Size([1, 3, 32, 32])
6761

6862

6963
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
7064
@pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
71-
@pytest.mark.parametrize("style_channels", [None, 32])
72-
@pytest.mark.parametrize("add_stem_skip", [False, True])
73-
def test_hovernet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
65+
@pytest.mark.parametrize("stem_skip_kws", [None, {"short_skip": "residual"}])
66+
@pytest.mark.parametrize("style_channels", [None, 256])
67+
def test_hovernet_fwdbwd(enc_name, model_type, stem_skip_kws, style_channels):
7468
x = torch.rand([1, 3, 64, 64])
7569
model = get_model(
7670
name="hovernet",
7771
type=model_type,
7872
enc_name=enc_name,
79-
ntypes=3,
80-
ntissues=3,
81-
style_channels=style_channels,
82-
add_stem_skip=add_stem_skip,
73+
n_type_classes=3,
74+
n_sem_classes=3,
8375
enc_pretrain=False,
76+
style_channels=style_channels,
77+
stem_skip_kws=stem_skip_kws,
8478
)
8579

8680
y = model(x)
87-
y["hovernet"].mean().backward()
81+
y["hovernet-hovernet"].mean().backward()
8882

89-
assert y["type"].shape == x.shape
83+
assert y["type-type"].shape == x.shape
9084

91-
if "sem" in y.keys():
92-
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
85+
if "sem-sem" in y.keys():
86+
assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
9387

9488

9589
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
9690
@pytest.mark.parametrize("model_type", ["base", "plus"])
97-
@pytest.mark.parametrize("style_channels", [None, 32])
98-
@pytest.mark.parametrize("add_stem_skip", [False, True])
99-
def test_stardist_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
91+
def test_stardist_fwdbwd(enc_name, model_type):
10092
n_rays = 3
10193
x = torch.rand([1, 3, 64, 64])
10294
model = get_model(
10395
name="stardist",
10496
type=model_type,
10597
n_rays=n_rays,
10698
enc_name=enc_name,
107-
ntypes=3,
108-
ntissues=3,
109-
style_channels=style_channels,
110-
add_stem_skip=add_stem_skip,
99+
n_type_classes=3,
100+
n_sem_classes=3,
111101
enc_pretrain=False,
112102
)
113103

114104
y = model(x)
115-
y["stardist"].mean().backward()
105+
y["stardist-stardist"].mean().backward()
116106

117-
assert y["type"].shape == x.shape
118-
assert y["stardist"].shape == torch.Size([1, n_rays, 64, 64])
107+
assert y["stardist-type"].shape == x.shape
108+
assert y["stardist-stardist"].shape == torch.Size([1, n_rays, 64, 64])
119109

120-
if "sem" in y.keys():
121-
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
110+
if "sem-sem" in y.keys():
111+
assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
122112

123113

124114
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
125115
@pytest.mark.parametrize("model_type", ["base", "plus"])
126-
@pytest.mark.parametrize("add_stem_skip", [False, True])
127-
def test_cellpose_fwdbwd(enc_name, model_type, add_stem_skip):
116+
def test_cellpose_fwdbwd(enc_name, model_type):
128117
x = torch.rand([1, 3, 64, 64])
129118
model = get_model(
130119
name="cellpose",
131120
type=model_type,
132121
enc_name=enc_name,
133-
ntypes=3,
134-
ntissues=3,
135-
add_stem_skip=add_stem_skip,
122+
n_type_classes=3,
123+
n_sem_classes=3,
136124
enc_pretrain=False,
137125
)
138126

139127
y = model(x)
140-
y["cellpose"].mean().backward()
128+
y["cellpose-cellpose"].mean().backward()
141129

142-
assert y["type"].shape == x.shape
143-
assert y["cellpose"].shape == torch.Size([1, 2, 64, 64])
130+
assert y["cellpose-type"].shape == x.shape
131+
assert y["cellpose-cellpose"].shape == torch.Size([1, 2, 64, 64])
144132

145-
if "sem" in y.keys():
146-
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
133+
if "sem-sem" in y.keys():
134+
assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])
147135

148136

149137
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
150138
@pytest.mark.parametrize("model_type", ["base", "plus"])
151-
@pytest.mark.parametrize("add_stem_skip", [False, True])
152-
def test_cellpose_fwdbwd(enc_name, model_type, add_stem_skip):
139+
def test_cellpose_fwdbwd(enc_name, model_type):
153140
x = torch.rand([1, 3, 64, 64])
154141
model = get_model(
155142
name="omnipose",
156143
type=model_type,
157144
enc_name=enc_name,
158-
ntypes=3,
159-
ntissues=3,
160-
add_stem_skip=add_stem_skip,
145+
n_type_classes=3,
146+
n_sem_classes=3,
161147
enc_pretrain=False,
162148
)
163149

164150
y = model(x)
165-
y["omnipose"].mean().backward()
166-
167-
assert y["type"].shape == x.shape
168-
assert y["omnipose"].shape == torch.Size([1, 2, 64, 64])
169-
170-
if "sem" in y.keys():
171-
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
151+
y["omnipose-omnipose"].mean().backward()
172152

153+
assert y["omnipose-type"].shape == x.shape
154+
assert y["omnipose-omnipose"].shape == torch.Size([1, 2, 64, 64])
173155

174-
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
175-
@pytest.mark.parametrize("add_stem_skip", [False, True])
176-
def test_multitaskunet_fwdbwd(enc_name, add_stem_skip):
177-
x = torch.rand([1, 3, 64, 64])
178-
m = MultiTaskUnet(
179-
decoders=("sem",),
180-
heads={"sem": {"sem": 3}},
181-
n_conv_layers={"sem": (1, 1, 1, 1)},
182-
n_conv_blocks={"sem": ((2,), (2,), (2,), (2,))},
183-
out_channels={"sem": (128, 64, 32, 16)},
184-
long_skips={"sem": "unet"},
185-
dec_params={"sem": None},
186-
add_stem_skip=add_stem_skip,
187-
enc_name=enc_name,
188-
enc_pretrain=False,
189-
)
190-
y = m(x)
191-
y["sem"].mean().backward()
192-
193-
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
156+
if "sem-sem" in y.keys():
157+
assert y["sem-sem"].shape == torch.Size([1, 3, 64, 64])

0 commit comments

Comments
 (0)