Skip to content

Commit 785d339

Browse files
authored
Merge pull request #1160 from MouseLand/gui_fixes_2
Gui fixes 2
2 parents 7000f95 + ad2c6f3 commit 785d339

File tree

7 files changed

+123
-64
lines changed

7 files changed

+123
-64
lines changed

cellpose/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from cellpose.version import version, version_str
2-
message = '''\n
3-
Welcome to CellposeSAM, cellpose v4.0.1! The neural network component of
2+
message = f'''\n
3+
Welcome to CellposeSAM, cellpose v{version_str}! The neural network component of
44
CPSAM is much larger than in previous versions and CPU excution is slow.
55
We encourage users to use GPU/MPS if available. \n\n'''
66
print(message)

cellpose/gui/gui.py

Lines changed: 76 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ def __init__(self, image=None, logger=None):
273273
self.ratio = 1.
274274
self.reset()
275275

276+
# This needs to go after .reset() is called to get state fully set up:
277+
self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked)
278+
276279
self.load_3D = False
277280

278281
# if called with image, load it
@@ -522,8 +525,11 @@ def make_buttons(self):
522525
self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings)
523526
self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9)
524527

525-
# connect edits to the diameter box to resizing the image:
528+
# connect edits to image processing steps:
526529
self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale)
530+
self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob)
531+
self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob)
532+
self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob)
527533

528534
# Needed to do this for the drop down to not be open on startup
529535
self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True)
@@ -951,12 +957,15 @@ def reset(self):
951957
self.opacity = 128 # how opaque masks should be
952958
self.outcolor = [200, 200, 255, 200]
953959
self.NZ, self.Ly, self.Lx = 1, 256, 256
954-
self.saturation = []
955-
for r in range(3):
956-
self.saturation.append([[0, 255] for n in range(self.NZ)])
957-
self.sliders[r].setValue([0, 255])
958-
self.sliders[r].setEnabled(False)
959-
self.sliders[r].show()
960+
self.saturation = self.saturation if hasattr(self, 'saturation') else []
961+
962+
# only adjust the saturation if auto-adjust is on:
963+
if self.autobtn.isChecked():
964+
for r in range(3):
965+
self.saturation.append([[0, 255] for n in range(self.NZ)])
966+
self.sliders[r].setValue([0, 255])
967+
self.sliders[r].setEnabled(False)
968+
self.sliders[r].show()
960969
self.currentZ = 0
961970
self.flows = [[], [], [], [], [[]]]
962971
# masks matrix
@@ -1655,6 +1664,10 @@ def get_normalize_params(self):
16551664
normalize_params = {**normalize_default, **normalize_params}
16561665

16571666
return normalize_params
1667+
1668+
def compute_saturation_if_checked(self):
1669+
if self.autobtn.isChecked():
1670+
self.compute_saturation()
16581671

16591672
def compute_saturation(self, return_img=False):
16601673
norm = self.get_normalize_params()
@@ -1704,42 +1717,43 @@ def compute_saturation(self, return_img=False):
17041717
else:
17051718
img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered
17061719

1707-
self.saturation = []
1708-
for c in range(img_norm.shape[-1]):
1709-
self.saturation.append([])
1710-
if np.ptp(img_norm[..., c]) > 1e-3:
1711-
if norm3D:
1712-
x01 = np.percentile(img_norm[..., c], percentile[0])
1713-
x99 = np.percentile(img_norm[..., c], percentile[1])
1714-
if invert:
1715-
x01i = 255. - x99
1716-
x99i = 255. - x01
1717-
x01, x99 = x01i, x99i
1718-
for n in range(self.NZ):
1719-
self.saturation[-1].append([x01, x99])
1720-
else:
1721-
for z in range(self.NZ):
1722-
if self.NZ > 1:
1723-
x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
1724-
x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
1725-
else:
1726-
x01 = np.percentile(img_norm[..., c], percentile[0])
1727-
x99 = np.percentile(img_norm[..., c], percentile[1])
1720+
if self.autobtn.isChecked():
1721+
self.saturation = []
1722+
for c in range(img_norm.shape[-1]):
1723+
self.saturation.append([])
1724+
if np.ptp(img_norm[..., c]) > 1e-3:
1725+
if norm3D:
1726+
x01 = np.percentile(img_norm[..., c], percentile[0])
1727+
x99 = np.percentile(img_norm[..., c], percentile[1])
17281728
if invert:
17291729
x01i = 255. - x99
17301730
x99i = 255. - x01
17311731
x01, x99 = x01i, x99i
1732-
self.saturation[-1].append([x01, x99])
1733-
else:
1734-
for n in range(self.NZ):
1735-
self.saturation[-1].append([0, 255.])
1736-
print(self.saturation[2][self.currentZ])
1732+
for n in range(self.NZ):
1733+
self.saturation[-1].append([x01, x99])
1734+
else:
1735+
for z in range(self.NZ):
1736+
if self.NZ > 1:
1737+
x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
1738+
x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
1739+
else:
1740+
x01 = np.percentile(img_norm[..., c], percentile[0])
1741+
x99 = np.percentile(img_norm[..., c], percentile[1])
1742+
if invert:
1743+
x01i = 255. - x99
1744+
x99i = 255. - x01
1745+
x01, x99 = x01i, x99i
1746+
self.saturation[-1].append([x01, x99])
1747+
else:
1748+
for n in range(self.NZ):
1749+
self.saturation[-1].append([0, 255.])
1750+
print(self.saturation[2][self.currentZ])
17371751

1738-
if img_norm.shape[-1] == 1:
1739-
self.saturation.append(self.saturation[0])
1740-
self.saturation.append(self.saturation[0])
1752+
if img_norm.shape[-1] == 1:
1753+
self.saturation.append(self.saturation[0])
1754+
self.saturation.append(self.saturation[0])
17411755

1742-
self.autobtn.setChecked(True)
1756+
# self.autobtn.setChecked(True)
17431757
self.update_plot()
17441758

17451759

@@ -1838,14 +1852,31 @@ def compute_cprob(self):
18381852
if self.recompute_masks:
18391853
flow_threshold = self.segmentation_settings.flow_threshold
18401854
cellprob_threshold = self.segmentation_settings.cellprob_threshold
1855+
niter = self.segmentation_settings.niter
1856+
min_size = int(self.min_size.text()) if not isinstance(
1857+
self.min_size, int) else self.min_size
1858+
18411859
self.logger.info(
18421860
"computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
18431861
(cellprob_threshold, flow_threshold))
1862+
1863+
try:
1864+
dP = self.flows[2].squeeze()
1865+
cellprob = self.flows[3].squeeze()
1866+
except IndexError:
1867+
self.logger.error("Flows don't exist, try running model again.")
1868+
return
1869+
18441870
maski = dynamics.resize_and_compute_masks(
1845-
self.flows[4][:-1], self.flows[4][-1], p=self.flows[3].copy(),
1846-
cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold,
1847-
resize=self.cellpix.shape[-2:])[0]
1848-
1871+
dP=dP,
1872+
cellprob=cellprob,
1873+
niter=niter,
1874+
do_3D=self.load_3D,
1875+
min_size=min_size,
1876+
# max_size_fraction=min_size_fraction, # Leave as default
1877+
cellprob_threshold=cellprob_threshold,
1878+
flow_threshold=flow_threshold)
1879+
18491880
self.masksOn = True
18501881
if not self.OCheckBox.isChecked():
18511882
self.MCheckBox.setChecked(True)
@@ -1911,6 +1942,9 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
19111942
flows_new.append(flows[0].copy()) # RGB flow
19121943
flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
19131944
255).astype("uint8")) # cellprob
1945+
flows_new.append(flows[1].copy()) # XY flows
1946+
flows_new.append(flows[2].copy()) # original cellprob
1947+
19141948
if self.load_3D:
19151949
if stitch_threshold == 0.:
19161950
flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
@@ -1963,7 +1997,7 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
19631997
self.masksOn = True
19641998
self.MCheckBox.setChecked(True)
19651999
self.progress.setValue(100)
1966-
if self.restore != "filter" and self.restore is not None:
2000+
if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked():
19672001
self.compute_saturation()
19682002
if not do_3D and not stitch_threshold > 0:
19692003
self.recompute_masks = True

cellpose/gui/guiparts.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ def __init__(self, font):
272272
grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
273273
self.flow_threshold_box = QLineEdit()
274274
self.flow_threshold_box.setText("0.4")
275-
# self.flow_threshold_box.returnPressed.connect(self.compute_cprob) # TODO
276275
self.flow_threshold_box.setFixedWidth(40)
277276
self.flow_threshold_box.setFont(font)
278277
grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
@@ -286,7 +285,6 @@ def __init__(self, font):
286285
grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
287286
self.cellprob_threshold_box = QLineEdit()
288287
self.cellprob_threshold_box.setText("0.0")
289-
# self.cellprob_threshold.returnPressed.connect(self.compute_cprob) # TODO
290288
self.cellprob_threshold_box.setFixedWidth(40)
291289
self.cellprob_threshold_box.setFont(font)
292290
self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
@@ -408,7 +406,12 @@ def cellprob_threshold(self):
408406

409407
@property
410408
def niter(self):
411-
return int(self.niter_box.text())
409+
num = int(self.niter_box.text())
410+
if num < 1:
411+
self.niter_box.setText('200')
412+
return 200
413+
else:
414+
return num
412415

413416

414417

cellpose/gui/io.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,13 @@ def _initialize_images(parent, image, load_3D=False):
211211
"GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
212212
)
213213
parent.compute_saturation()
214-
elif len(parent.saturation) != parent.NZ:
215-
parent.saturation = []
216-
for r in range(3):
217-
parent.saturation.append([])
218-
for n in range(parent.NZ):
219-
parent.saturation[-1].append([0, 255])
220-
parent.sliders[r].setValue([0, 255])
214+
# elif len(parent.saturation) != parent.NZ:
215+
# parent.saturation = []
216+
# for r in range(3):
217+
# parent.saturation.append([])
218+
# for n in range(parent.NZ):
219+
# parent.saturation[-1].append([0, 255])
220+
# parent.sliders[r].setValue([0, 255])
221221
parent.compute_scale()
222222
parent.track_changes = []
223223

cellpose/gui/make_train.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def main():
7-
parser = argparse.ArgumentParser(description='cellpose parameters')
7+
parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')
88

99
input_img_args = parser.add_argument_group("input image arguments")
1010
input_img_args.add_argument('--dir', default=[], type=str,
@@ -19,22 +19,22 @@ def main():
1919
input_img_args.add_argument('--img_filter', default=[], type=str,
2020
help='end string for images to run on')
2121
input_img_args.add_argument(
22-
'--channel_axis', default=None, type=int,
22+
'--channel_axis', default=-1, type=int,
2323
help='axis of image which corresponds to image channels')
24-
input_img_args.add_argument('--z_axis', default=None, type=int,
24+
input_img_args.add_argument('--z_axis', default=0, type=int,
2525
help='axis of image which corresponds to Z dimension')
2626
input_img_args.add_argument(
2727
'--chan', default=0, type=int, help=
28-
'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s')
28+
'Deprecated')
2929
input_img_args.add_argument(
3030
'--chan2', default=0, type=int, help=
31-
'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
31+
'Deprecated'
3232
)
3333
input_img_args.add_argument('--invert', action='store_true',
3434
help='invert grayscale channel')
3535
input_img_args.add_argument(
3636
'--all_channels', action='store_true', help=
37-
'use all channels in image if using own model and images with special channels')
37+
'deprecated')
3838
input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
3939
help="anisotropy of volume in 3D")
4040

@@ -77,8 +77,13 @@ def main():
7777
npm = ["YX", "ZY", "ZX"]
7878
for name in image_names:
7979
name0 = os.path.splitext(os.path.split(name)[-1])[0]
80-
img0 = io.imread(name)
81-
img0 = transforms.convert_image(img0, channels=[args.chan, args.chan2], channel_axis=args.channel_axis, z_axis=args.z_axis)
80+
img0 = io.imread_3D(name)
81+
try:
82+
img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
83+
z_axis=args.z_axis, do_3D=True)
84+
except ValueError:
85+
print('Error converting image. Did you provide the correct --channel_axis and --z_axis ?')
86+
8287
for p in range(3):
8388
img = img0.transpose(pm[p]).copy()
8489
print(npm[p], img[0].shape)

cellpose/io.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,17 @@ def imread_2D(img_file):
235235
# move channel to last dim:
236236
img = np.moveaxis(img, 0, -1)
237237

238+
nchan = img.shape[2]
239+
238240
if img.ndim == 3:
239-
if img.shape[2] == 3:
241+
if nchan == 3:
240242
# already has 3 channels
241243
return img
242244

243245
# ensure there are 3 channels
244246
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
245-
img_out[:, :, :img.shape[2]] = img
247+
copy_chan = min(3, nchan)
248+
img_out[:, :, :copy_chan] = img[:, :, :copy_chan]
246249

247250
elif img.ndim == 2:
248251
# add a channel dimension

tests/test_train.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from subprocess import check_output, STDOUT
33
import os, shutil
44
import torch
5+
from pathlib import Path
56

67

78
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@@ -41,3 +42,16 @@ def test_cli_train(data_dir):
4142
print(e)
4243
raise ValueError(e)
4344

45+
46+
def test_cli_make_train(data_dir):
47+
script_name = Path().resolve() / 'cellpose/gui/make_train.py'
48+
image_path = data_dir / '3D/gray_3D.tif'
49+
50+
cmd = f'python {script_name} --image_path {image_path}'
51+
res = check_output(cmd, stderr=STDOUT, shell=True)
52+
53+
# there should be 30 slices:
54+
files = [f for f in (data_dir / '3D/train/').iterdir() if 'gray_3D' in f.name]
55+
assert 30 == len(files)
56+
57+
shutil.rmtree((data_dir / '3D/train'))

0 commit comments

Comments
 (0)