Skip to content

Commit a75db82

Browse files
committed
fix pytest mock model to be same as cp model. Rename tests accordingly, add slow test to check full model with 2d grayscale img
1 parent 1575f33 commit a75db82

File tree

3 files changed

+66
-49
lines changed

3 files changed

+66
-49
lines changed

conftest.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def data_dir(image_names):
5656

5757

5858
@pytest.fixture()
59-
def cellposemodel_fixture_2D():
59+
def cellposemodel_fixture_24layer():
6060
""" This is functionally identical to CellposeModel but uses mock class """
6161
use_gpu = torch.cuda.is_available()
6262
use_mps = 'mps' if torch.backends.mps.is_available() else False
@@ -66,10 +66,12 @@ def cellposemodel_fixture_2D():
6666

6767

6868
@pytest.fixture()
69-
def cellposemodel_fixture_3D():
69+
def cellposemodel_fixture_2layer():
7070
""" This is only uses 2 transformer blocks for speed """
71-
use_gpu = torch.cuda.is_available() # Turn of gpu for mac 3d
72-
model = MockCellposeModel(2, gpu=use_gpu)
71+
use_gpu = torch.cuda.is_available()
72+
use_mps = 'mps' if torch.backends.mps.is_available() else False
73+
gpu = use_gpu or use_mps
74+
model = MockCellposeModel(n_keep_layers=2, gpu=gpu)
7375
yield model
7476

7577

@@ -107,11 +109,12 @@ def __init__(self, n_keep_layers=2, gpu=False):
107109
super().__init__(gpu=gpu)
108110

109111
self.net = MockTransformer(n_keep_layers)
112+
self.net.to(self.device)
110113
self.net.load_model(Path().home() / '.cellpose/models/cpsam', device=self.device)
111114

112-
def eval(self, x, **kwargs):
115+
def eval(self, *args, **kwargs):
113116
tic = time.time()
114-
res = super().eval(x, **kwargs)
117+
res = super().eval(*args, **kwargs)
115118
toc = time.time()
116119

117120
print(f'eval() time elapsed: {toc-tic}')

tests/test_output.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cellpose import io, metrics, utils
1+
from cellpose import io, metrics, utils, models
22
import pytest
33
from subprocess import check_output, STDOUT
44
import os
@@ -38,38 +38,53 @@ def clear_output(data_dir, image_names):
3838
os.remove(npy_output)
3939

4040

41-
42-
def test_class_2D(data_dir, image_names, cellposemodel_fixture_2D):
41+
def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer):
4342
clear_output(data_dir, image_names)
43+
44+
img_file = data_dir / '2D' / image_names[0]
45+
46+
img = io.imread_2D(img_file)
47+
# flowps = io.imread(img_file.parent / (img_file.stem + "_cp4_gt_flowps.tif"))
48+
49+
masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True)
50+
io.imsave(data_dir / '2D' / (img_file.stem + "_cp_masks.png"), masks_pred)
51+
# flowsp_pred = np.concatenate([flows_pred[1], flows_pred[2][None, ...]], axis=0)
52+
# mse = np.sqrt((flowsp_pred - flowps) ** 2).sum()
53+
# assert mse.sum() < 1e-8, f"MSE of flows is too high: {mse.sum()} on image {image_name}"
54+
# print("MSE of flows is %f" % mse.mean())
55+
56+
compare_masks_cp4(data_dir, image_names[0], "2D")
57+
# clear_output(data_dir, image_names)
58+
4459

60+
@pytest.mark.slow
61+
def test_class_2D_all_imgs(data_dir, image_names, cellposemodel_fixture_24layer):
62+
clear_output(data_dir, image_names)
4563
for image_name in image_names:
46-
4764
img_file = data_dir / '2D' / image_name
4865

4966
img = io.imread_2D(img_file)
5067
# flowps = io.imread(img_file.parent / (img_file.stem + "_cp4_gt_flowps.tif"))
5168

52-
masks_pred, _, _ = cellposemodel_fixture_2D.eval(img, normalize=True)
69+
masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True)
5370
io.imsave(data_dir / '2D' / (img_file.stem + "_cp_masks.png"), masks_pred)
5471
# flowsp_pred = np.concatenate([flows_pred[1], flows_pred[2][None, ...]], axis=0)
5572
# mse = np.sqrt((flowsp_pred - flowps) ** 2).sum()
5673
# assert mse.sum() < 1e-8, f"MSE of flows is too high: {mse.sum()} on image {image_name}"
5774
# print("MSE of flows is %f" % mse.mean())
5875

59-
break # Just test one image for now
60-
61-
compare_masks_cp4(data_dir, image_names[0], "2D")
76+
compare_masks_cp4(data_dir, image_names, "2D")
6277
clear_output(data_dir, image_names)
6378

6479

6580
@pytest.mark.slow
66-
def test_cyto2_to_seg(data_dir, image_names, cellposemodel_fixture_2D):
81+
def test_cyto2_to_seg(data_dir, image_names, cellposemodel_fixture_24layer):
6782
clear_output(data_dir, image_names)
6883
file_names = [data_dir / "2D" / n for n in image_names]
6984
imgs = [io.imread_2D(file_name) for file_name in file_names]
7085

7186
# masks, flows, styles = model.eval(imgs, diameter=30) # Errors during SAM stuff
72-
masks, flows, _ = cellposemodel_fixture_2D.eval(imgs, bsize=256, batch_size=64, normalize=True)
87+
masks, flows, _ = cellposemodel_fixture_24layer.eval(imgs, bsize=256, batch_size=64, normalize=True)
7388

7489
for file_name, mask in zip(file_names, masks):
7590
io.imsave(data_dir/'2D'/(file_name.stem + '_cp_masks.png'), mask)
@@ -79,13 +94,13 @@ def test_cyto2_to_seg(data_dir, image_names, cellposemodel_fixture_2D):
7994
clear_output(data_dir, image_names)
8095

8196

82-
def test_class_3D(data_dir, image_names_3d, cellposemodel_fixture_3D):
97+
def test_class_3D(data_dir, image_names_3d, cellposemodel_fixture_2layer):
8398
clear_output(data_dir, image_names_3d)
8499

85100
for image_name in image_names_3d:
86101
img_file = data_dir / '3D' / image_name
87102
img = io.imread_3D(img_file)
88-
masks_pred, flows_pred, _ = cellposemodel_fixture_3D.eval(img, do_3D=True, channel_axis=-1, z_axis=0)
103+
masks_pred, flows_pred, _ = cellposemodel_fixture_2layer.eval(img, do_3D=True, channel_axis=-1, z_axis=0)
89104
# io.imsave(data_dir / "3D" / (img_file.stem + "_cp_masks.tif"), masks)
90105

91106
assert img.shape[:-1] == masks_pred.shape, f'mask incorrect shape for {image_name}, {masks_pred.shape=}'
@@ -129,15 +144,15 @@ def test_cli_3D_diam(data_dir, image_names_3d):
129144

130145

131146
@pytest.mark.slow
132-
def test_outlines_list(data_dir, image_names, cellposemodel_fixture_2D):
147+
def test_outlines_list(data_dir, image_names, cellposemodel_fixture_24layer):
133148
""" test both single and multithreaded by comparing them"""
134149
clear_output(data_dir, image_names)
135150
image_name = "rgb_2D.png"
136151

137152
file_name = str(data_dir.joinpath("2D").joinpath(image_name))
138153
img = io.imread(file_name)
139154

140-
masks, _, _ = cellposemodel_fixture_2D.eval(img, diameter=30)
155+
masks, _, _ = cellposemodel_fixture_24layer.eval(img, diameter=30)
141156
outlines_single = utils.outlines_list(masks, multiprocessing=False)
142157
outlines_multi = utils.outlines_list(masks, multiprocessing=True)
143158

tests/test_shape.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,55 @@
33

44

55
#################### 2D Tests ####################
6-
@pytest.mark.slow
7-
def test_shape_2D_grayscale(cellposemodel_fixture_2D):
6+
def test_shape_2D_grayscale(cellposemodel_fixture_24layer):
87
img = np.zeros((224, 224))
9-
masks, _, _ = cellposemodel_fixture_2D.eval(img)
8+
masks, _, _ = cellposemodel_fixture_24layer.eval(img)
109
assert masks.shape == (224, 224)
1110

1211

13-
def test_shape_2D_chan_first_diam_resize(cellposemodel_fixture_2D):
12+
def test_shape_2D_chan_first_diam_resize(cellposemodel_fixture_24layer):
1413
img = np.zeros((1, 224, 224))
15-
masks, flows, _ = cellposemodel_fixture_2D.eval(img, diameter=50)
14+
masks, flows, _ = cellposemodel_fixture_24layer.eval(img, diameter=50)
1615
assert masks.shape == (224, 224), 'mask shape mismatch'
1716
assert flows[1].shape == (2, 224, 224), 'dP shape mismatch'
1817
assert flows[2].shape == (224, 224), 'cellprob shape mismatch'
1918

2019

2120
@pytest.mark.slow
22-
def test_shape_2D_chan_diam_resize(cellposemodel_fixture_2D):
21+
def test_shape_2D_chan_diam_resize(cellposemodel_fixture_24layer):
2322
img = np.zeros((1, 224, 224))
24-
masks, _, _ = cellposemodel_fixture_2D.eval(img, diameter=50)
23+
masks, _, _ = cellposemodel_fixture_24layer.eval(img, diameter=50)
2524
assert masks.shape == (224, 224)
2625

2726

28-
def test_shape_2D_chan_last(cellposemodel_fixture_2D):
27+
def test_shape_2D_chan_last(cellposemodel_fixture_24layer):
2928
img = np.zeros((224, 224, 2))
30-
masks, flows, _ = cellposemodel_fixture_2D.eval(img)
29+
masks, flows, _ = cellposemodel_fixture_24layer.eval(img)
3130
assert masks.shape == (224, 224), 'mask shape mismatch'
3231
assert flows[1].shape == (2, 224, 224), 'dP shape mismatch'
3332
assert flows[2].shape == (224, 224), 'cellprob shape mismatch'
3433

3534

3635

3736
@pytest.mark.slow
38-
def test_shape_2D_chan_specify(cellposemodel_fixture_2D):
37+
def test_shape_2D_chan_specify(cellposemodel_fixture_24layer):
3938
img = np.zeros((224, 224, 2))
40-
masks, _, _ = cellposemodel_fixture_2D.eval(img, channel_axis=-1)
39+
masks, _, _ = cellposemodel_fixture_24layer.eval(img, channel_axis=-1)
4140
assert masks.shape == (224, 224)
4241

4342

44-
def test_shape_2D_2chan_specify(cellposemodel_fixture_2D):
43+
def test_shape_2D_2chan_specify(cellposemodel_fixture_24layer):
4544
img = np.zeros((224, 5, 224))
46-
masks, flows, _ = cellposemodel_fixture_2D.eval(img, channels=[2, 1], channel_axis=1)
45+
masks, flows, _ = cellposemodel_fixture_24layer.eval(img, channels=[2, 1], channel_axis=1)
4746
assert masks.shape == (224, 224), 'mask shape mismatch'
4847
assert flows[1].shape == (2, 224, 224), 'dP shape mismatch'
4948
assert flows[2].shape == (224, 224), 'cellprob shape mismatch'
5049

5150

5251
#################### 3D Tests ####################
53-
def test_shape_stitch(cellposemodel_fixture_3D):
52+
def test_shape_stitch(cellposemodel_fixture_2layer):
5453
img = np.zeros((5, 80, 80, 2)) # 5 layer 3d input, 2 channels
55-
masks, flows, _ = cellposemodel_fixture_3D.eval(img, channels=[0, 0],
54+
masks, flows, _ = cellposemodel_fixture_2layer.eval(img, channels=[0, 0],
5655
stitch_threshold=0.9,
5756
channel_axis=3, z_axis=0,
5857
do_3D=False)
@@ -63,52 +62,52 @@ def test_shape_stitch(cellposemodel_fixture_3D):
6362

6463

6564
@pytest.mark.slow
66-
def test_shape_3D(cellposemodel_fixture_3D):
65+
def test_shape_3D(cellposemodel_fixture_2layer):
6766
img = np.zeros((80, 80, 5, 1))
68-
masks, _, _ = cellposemodel_fixture_3D.eval(img, channel_axis=3, z_axis=2, do_3D=True)
67+
masks, _, _ = cellposemodel_fixture_2layer.eval(img, channel_axis=3, z_axis=2, do_3D=True)
6968
assert masks.shape == (5, 80, 80)
7069

7170

7271
@pytest.mark.slow
73-
def test_shape_3D_1ch(cellposemodel_fixture_3D):
72+
def test_shape_3D_1ch(cellposemodel_fixture_2layer):
7473
img = np.zeros((5, 80, 80, 1))
75-
masks, _, _ = cellposemodel_fixture_3D.eval(img, channel_axis=3, z_axis=0, do_3D=True)
74+
masks, _, _ = cellposemodel_fixture_2layer.eval(img, channel_axis=3, z_axis=0, do_3D=True)
7675
assert masks.shape == (5, 80, 80)
7776

7877

7978
@pytest.mark.slow
80-
def test_shape_3D_1ch_3ndim(cellposemodel_fixture_3D):
79+
def test_shape_3D_1ch_3ndim(cellposemodel_fixture_2layer):
8180
img = np.zeros((5, 80, 80))
82-
masks, _, _ = cellposemodel_fixture_3D.eval(img, channel_axis=None, z_axis=0, do_3D=True)
81+
masks, _, _ = cellposemodel_fixture_2layer.eval(img, channel_axis=None, z_axis=0, do_3D=True)
8382
assert masks.shape == (5, 80, 80)
8483

8584

8685
@pytest.mark.slow
87-
def test_shape_3D_1ch_3ndim_diam(cellposemodel_fixture_3D):
86+
def test_shape_3D_1ch_3ndim_diam(cellposemodel_fixture_2layer):
8887
img = np.zeros((5, 80, 80))
89-
masks, _, _ = cellposemodel_fixture_3D.eval(img, channel_axis=None, diameter=50, z_axis=0, do_3D=True)
88+
masks, _, _ = cellposemodel_fixture_2layer.eval(img, channel_axis=None, diameter=50, z_axis=0, do_3D=True)
9089
assert masks.shape == (5, 80, 80)
9190

9291

93-
def test_shape_3D_2ch(cellposemodel_fixture_3D):
92+
def test_shape_3D_2ch(cellposemodel_fixture_2layer):
9493
img = np.zeros((80, 2, 80, 4))
9594

96-
masks, flows, _ = cellposemodel_fixture_3D.eval(img, z_axis=-1, channel_axis=1, do_3D=True)
95+
masks, flows, _ = cellposemodel_fixture_2layer.eval(img, z_axis=-1, channel_axis=1, do_3D=True)
9796
assert masks.shape == (4, 80, 80), 'mask shape mismatch'
9897
assert flows[1].shape == (3, 4, 80, 80), 'dP shape mismatch'
9998
assert flows[2].shape == (4, 80, 80), 'cellprob shape mismatch'
10099

101100

102101
@pytest.mark.slow
103-
def test_shape_3D_rgb_diam(cellposemodel_fixture_3D):
102+
def test_shape_3D_rgb_diam(cellposemodel_fixture_2layer):
104103
img = np.zeros((5, 80, 80, 3))
105-
masks, _, _ = cellposemodel_fixture_3D.eval(img, diameter=50, channels=[0, 0],
104+
masks, _, _ = cellposemodel_fixture_2layer.eval(img, diameter=50, channels=[0, 0],
106105
channel_axis=3, z_axis=0, do_3D=True)
107106
assert masks.shape == (5, 80, 80)
108107

109108
@pytest.mark.slow
110-
def test_shape_3D_rgb(cellposemodel_fixture_3D):
109+
def test_shape_3D_rgb(cellposemodel_fixture_2layer):
111110
img = np.zeros((5, 80, 80, 3))
112-
masks, _, _ = cellposemodel_fixture_3D.eval(img, channels=[0, 0],
111+
masks, _, _ = cellposemodel_fixture_2layer.eval(img, channels=[0, 0],
113112
channel_axis=3, z_axis=0, do_3D=True)
114113
assert masks.shape == (5, 80, 80)

0 commit comments

Comments
 (0)