Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cellpose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from cellpose.version import version, version_str
message = '''\n
Welcome to CellposeSAM, cellpose v4.0.1! The neural network component of
message = f'''\n
Welcome to CellposeSAM, cellpose v{version_str}! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow.
We encourage users to use GPU/MPS if available. \n\n'''
print(message)
118 changes: 76 additions & 42 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ def __init__(self, image=None, logger=None):
self.ratio = 1.
self.reset()

# This needs to go after .reset() is called to get state fully set up:
self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked)

self.load_3D = False

# if called with image, load it
Expand Down Expand Up @@ -522,8 +525,11 @@ def make_buttons(self):
self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings)
self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9)

# connect edits to the diameter box to resizing the image:
# connect edits to image processing steps:
self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale)
self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob)
self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob)
self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob)

# Needed to do this for the drop down to not be open on startup
self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True)
Expand Down Expand Up @@ -951,12 +957,15 @@ def reset(self):
self.opacity = 128 # how opaque masks should be
self.outcolor = [200, 200, 255, 200]
self.NZ, self.Ly, self.Lx = 1, 256, 256
self.saturation = []
for r in range(3):
self.saturation.append([[0, 255] for n in range(self.NZ)])
self.sliders[r].setValue([0, 255])
self.sliders[r].setEnabled(False)
self.sliders[r].show()
self.saturation = self.saturation if hasattr(self, 'saturation') else []

# only adjust the saturation if auto-adjust is on:
if self.autobtn.isChecked():
for r in range(3):
self.saturation.append([[0, 255] for n in range(self.NZ)])
self.sliders[r].setValue([0, 255])
self.sliders[r].setEnabled(False)
self.sliders[r].show()
self.currentZ = 0
self.flows = [[], [], [], [], [[]]]
# masks matrix
Expand Down Expand Up @@ -1655,6 +1664,10 @@ def get_normalize_params(self):
normalize_params = {**normalize_default, **normalize_params}

return normalize_params

def compute_saturation_if_checked(self):
if self.autobtn.isChecked():
self.compute_saturation()

def compute_saturation(self, return_img=False):
norm = self.get_normalize_params()
Expand Down Expand Up @@ -1704,42 +1717,43 @@ def compute_saturation(self, return_img=False):
else:
img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered

self.saturation = []
for c in range(img_norm.shape[-1]):
self.saturation.append([])
if np.ptp(img_norm[..., c]) > 1e-3:
if norm3D:
x01 = np.percentile(img_norm[..., c], percentile[0])
x99 = np.percentile(img_norm[..., c], percentile[1])
if invert:
x01i = 255. - x99
x99i = 255. - x01
x01, x99 = x01i, x99i
for n in range(self.NZ):
self.saturation[-1].append([x01, x99])
else:
for z in range(self.NZ):
if self.NZ > 1:
x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
else:
x01 = np.percentile(img_norm[..., c], percentile[0])
x99 = np.percentile(img_norm[..., c], percentile[1])
if self.autobtn.isChecked():
self.saturation = []
for c in range(img_norm.shape[-1]):
self.saturation.append([])
if np.ptp(img_norm[..., c]) > 1e-3:
if norm3D:
x01 = np.percentile(img_norm[..., c], percentile[0])
x99 = np.percentile(img_norm[..., c], percentile[1])
if invert:
x01i = 255. - x99
x99i = 255. - x01
x01, x99 = x01i, x99i
self.saturation[-1].append([x01, x99])
else:
for n in range(self.NZ):
self.saturation[-1].append([0, 255.])
print(self.saturation[2][self.currentZ])
for n in range(self.NZ):
self.saturation[-1].append([x01, x99])
else:
for z in range(self.NZ):
if self.NZ > 1:
x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
else:
x01 = np.percentile(img_norm[..., c], percentile[0])
x99 = np.percentile(img_norm[..., c], percentile[1])
if invert:
x01i = 255. - x99
x99i = 255. - x01
x01, x99 = x01i, x99i
self.saturation[-1].append([x01, x99])
else:
for n in range(self.NZ):
self.saturation[-1].append([0, 255.])
print(self.saturation[2][self.currentZ])

if img_norm.shape[-1] == 1:
self.saturation.append(self.saturation[0])
self.saturation.append(self.saturation[0])
if img_norm.shape[-1] == 1:
self.saturation.append(self.saturation[0])
self.saturation.append(self.saturation[0])

self.autobtn.setChecked(True)
# self.autobtn.setChecked(True)
self.update_plot()


Expand Down Expand Up @@ -1838,14 +1852,31 @@ def compute_cprob(self):
if self.recompute_masks:
flow_threshold = self.segmentation_settings.flow_threshold
cellprob_threshold = self.segmentation_settings.cellprob_threshold
niter = self.segmentation_settings.niter
min_size = int(self.min_size.text()) if not isinstance(
self.min_size, int) else self.min_size

self.logger.info(
"computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
(cellprob_threshold, flow_threshold))

try:
dP = self.flows[2].squeeze()
cellprob = self.flows[3].squeeze()
except IndexError:
self.logger.error("Flows don't exist, try running model again.")
return

maski = dynamics.resize_and_compute_masks(
self.flows[4][:-1], self.flows[4][-1], p=self.flows[3].copy(),
cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold,
resize=self.cellpix.shape[-2:])[0]

dP=dP,
cellprob=cellprob,
niter=niter,
do_3D=self.load_3D,
min_size=min_size,
# max_size_fraction=min_size_fraction, # Leave as default
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold)

self.masksOn = True
if not self.OCheckBox.isChecked():
self.MCheckBox.setChecked(True)
Expand Down Expand Up @@ -1911,6 +1942,9 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
flows_new.append(flows[0].copy()) # RGB flow
flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
255).astype("uint8")) # cellprob
flows_new.append(flows[1].copy()) # XY flows
flows_new.append(flows[2].copy()) # original cellprob

if self.load_3D:
if stitch_threshold == 0.:
flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
Expand Down Expand Up @@ -1963,7 +1997,7 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
self.masksOn = True
self.MCheckBox.setChecked(True)
self.progress.setValue(100)
if self.restore != "filter" and self.restore is not None:
if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked():
self.compute_saturation()
if not do_3D and not stitch_threshold > 0:
self.recompute_masks = True
Expand Down
9 changes: 6 additions & 3 deletions cellpose/gui/guiparts.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def __init__(self, font):
grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
self.flow_threshold_box = QLineEdit()
self.flow_threshold_box.setText("0.4")
# self.flow_threshold_box.returnPressed.connect(self.compute_cprob) # TODO
self.flow_threshold_box.setFixedWidth(40)
self.flow_threshold_box.setFont(font)
grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
Expand All @@ -286,7 +285,6 @@ def __init__(self, font):
grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
self.cellprob_threshold_box = QLineEdit()
self.cellprob_threshold_box.setText("0.0")
# self.cellprob_threshold.returnPressed.connect(self.compute_cprob) # TODO
self.cellprob_threshold_box.setFixedWidth(40)
self.cellprob_threshold_box.setFont(font)
self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
Expand Down Expand Up @@ -408,7 +406,12 @@ def cellprob_threshold(self):

@property
def niter(self):
return int(self.niter_box.text())
num = int(self.niter_box.text())
if num < 1:
self.niter_box.setText('200')
return 200
else:
return num



Expand Down
14 changes: 7 additions & 7 deletions cellpose/gui/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,13 @@ def _initialize_images(parent, image, load_3D=False):
"GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
)
parent.compute_saturation()
elif len(parent.saturation) != parent.NZ:
parent.saturation = []
for r in range(3):
parent.saturation.append([])
for n in range(parent.NZ):
parent.saturation[-1].append([0, 255])
parent.sliders[r].setValue([0, 255])
# elif len(parent.saturation) != parent.NZ:
# parent.saturation = []
# for r in range(3):
# parent.saturation.append([])
# for n in range(parent.NZ):
# parent.saturation[-1].append([0, 255])
# parent.sliders[r].setValue([0, 255])
parent.compute_scale()
parent.track_changes = []

Expand Down
21 changes: 13 additions & 8 deletions cellpose/gui/make_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def main():
parser = argparse.ArgumentParser(description='cellpose parameters')
parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')

input_img_args = parser.add_argument_group("input image arguments")
input_img_args.add_argument('--dir', default=[], type=str,
Expand All @@ -19,22 +19,22 @@ def main():
input_img_args.add_argument('--img_filter', default=[], type=str,
help='end string for images to run on')
input_img_args.add_argument(
'--channel_axis', default=None, type=int,
'--channel_axis', default=-1, type=int,
help='axis of image which corresponds to image channels')
input_img_args.add_argument('--z_axis', default=None, type=int,
input_img_args.add_argument('--z_axis', default=0, type=int,
help='axis of image which corresponds to Z dimension')
input_img_args.add_argument(
'--chan', default=0, type=int, help=
'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s')
'Deprecated')
input_img_args.add_argument(
'--chan2', default=0, type=int, help=
'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
'Deprecated'
)
input_img_args.add_argument('--invert', action='store_true',
help='invert grayscale channel')
input_img_args.add_argument(
'--all_channels', action='store_true', help=
'use all channels in image if using own model and images with special channels')
'deprecated')
input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
help="anisotropy of volume in 3D")

Expand Down Expand Up @@ -77,8 +77,13 @@ def main():
npm = ["YX", "ZY", "ZX"]
for name in image_names:
name0 = os.path.splitext(os.path.split(name)[-1])[0]
img0 = io.imread(name)
img0 = transforms.convert_image(img0, channels=[args.chan, args.chan2], channel_axis=args.channel_axis, z_axis=args.z_axis)
img0 = io.imread_3D(name)
try:
img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
z_axis=args.z_axis, do_3D=True)
except ValueError:
print('Error converting image. Did you provide the correct --channel_axis and --z_axis ?')

for p in range(3):
img = img0.transpose(pm[p]).copy()
print(npm[p], img[0].shape)
Expand Down
7 changes: 5 additions & 2 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,17 @@ def imread_2D(img_file):
# move channel to last dim:
img = np.moveaxis(img, 0, -1)

nchan = img.shape[2]

if img.ndim == 3:
if img.shape[2] == 3:
if nchan == 3:
# already has 3 channels
return img

# ensure there are 3 channels
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
img_out[:, :, :img.shape[2]] = img
copy_chan = min(3, nchan)
img_out[:, :, :copy_chan] = img[:, :, :copy_chan]

elif img.ndim == 2:
# add a channel dimension
Expand Down
14 changes: 14 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from subprocess import check_output, STDOUT
import os, shutil
import torch
from pathlib import Path


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
Expand Down Expand Up @@ -41,3 +42,16 @@ def test_cli_train(data_dir):
print(e)
raise ValueError(e)


def test_cli_make_train(data_dir):
script_name = Path().resolve() / 'cellpose/gui/make_train.py'
image_path = data_dir / '3D/gray_3D.tif'

cmd = f'python {script_name} --image_path {image_path}'
res = check_output(cmd, stderr=STDOUT, shell=True)

# there should be 30 slices:
files = [f for f in (data_dir / '3D/train/').iterdir() if 'gray_3D' in f.name]
assert 30 == len(files)

shutil.rmtree((data_dir / '3D/train'))