Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions cellpose/gui/delete_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Utilities for deleting and relabeling GUI masks.
"""

import numpy as np


def normalize_remove_ids(remove_ids, ncells):
"""Return unique valid label IDs in descending order."""
remove_ids = np.asarray(remove_ids, dtype=np.int64).reshape(-1)
if remove_ids.size == 0 or ncells <= 0:
return np.zeros(0, dtype=np.int64)
valid = (remove_ids > 0) & (remove_ids <= int(ncells))
if not np.any(valid):
return np.zeros(0, dtype=np.int64)
remove_ids = np.unique(remove_ids[valid])
return np.sort(remove_ids)[::-1]


def batch_delete_reindex(cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids):
"""Delete labels and reindex all state in one pass.

Returns updated `(cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids, remove_mask)`.
"""
if cellpix.shape != outpix.shape:
raise ValueError("cellpix and outpix must have the same shape")

ncells = int(len(cellcolors) - 1)
remove_ids = normalize_remove_ids(remove_ids, ncells)
if remove_ids.size == 0:
remove_mask = np.zeros(ncells + 1, dtype=bool)
return (
cellpix,
outpix,
ismanual,
cellcolors,
list(zdraw),
remove_ids,
remove_mask,
)

remove_mask = np.zeros(ncells + 1, dtype=bool)
remove_mask[remove_ids] = True
keep_mask = ~remove_mask

lut_dtype = cellpix.dtype if np.issubdtype(cellpix.dtype, np.integer) else np.int64
relabel_map = np.cumsum(keep_mask, dtype=lut_dtype) - 1
relabel_map[remove_mask] = 0

cellpix = relabel_map[cellpix]
outpix = relabel_map[outpix]
ismanual = ismanual[keep_mask[1:]]
cellcolors = cellcolors[keep_mask]
zdraw = [z for z, keep in zip(zdraw, keep_mask[1:]) if keep]

return cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids, remove_mask
62 changes: 55 additions & 7 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from scipy.stats import mode
import cv2

from . import guiparts, menus, io
from . import guiparts, menus, io, delete_utils
from .. import models, core, dynamics, version, train
from ..utils import download_url_to_file, masks_to_outlines, diameters
from ..io import get_image_files, imsave, imread
Expand Down Expand Up @@ -1090,12 +1090,60 @@ def unselect_cell_multi(self, idx):
def remove_cell(self, idx):
if isinstance(idx, (int, np.integer)):
idx = [idx]
# because the function remove_single_cell updates the state of the cellpix and outpix arrays
# by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order
# so that the indices are correct
idx.sort(reverse=True)
for i in idx:
self.remove_single_cell(i)
idx = delete_utils.normalize_remove_ids(idx, self.ncells.get())
if idx.size == 0:
return

if idx.size == 1:
self.remove_single_cell(int(idx[0]))
else:
self.selected = 0
remove_mask = np.zeros(self.ncells.get() + 1, dtype=bool)
remove_mask[idx] = True

if self.currentZ < self.cellpix.shape[0]:
self.layerz[remove_mask[self.cellpix[self.currentZ]]] = np.array([0, 0, 0,
0])

if self.NZ == 1:
last_idx = int(idx[-1])
cp_last = self.cellpix[0] == last_idx
op_last = self.outpix[0] == last_idx
self.removed_cell = [
self.ismanual[last_idx - 1], self.cellcolors[last_idx],
np.nonzero(cp_last),
np.nonzero(op_last)
]
self.redo.setEnabled(True)

ar_all, ac_all = np.nonzero(remove_mask[self.cellpix[0]])
coord_map = {}
if ar_all.size > 0:
labels = self.cellpix[0, ar_all, ac_all]
order = np.argsort(labels, kind="mergesort")
labels = labels[order]
ar_all = ar_all[order]
ac_all = ac_all[order]
unique_labels, first_inds = np.unique(labels, return_index=True)
last_inds = np.append(first_inds[1:], labels.size)
for label, i0, i1 in zip(unique_labels, first_inds, last_inds):
coord_map[int(label)] = (ar_all[i0:i1], ac_all[i0:i1])

for i in idx:
ar, ac = coord_map.get(
int(i), (np.zeros(0, np.int64), np.zeros(0, np.int64)))
d = datetime.datetime.now()
self.track_changes.append(
[d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
print("GUI_INFO: removed cell %d" % (i - 1))
else:
for i in idx:
print("GUI_INFO: removed cell %d" % (i - 1))

(self.cellpix, self.outpix, self.ismanual, self.cellcolors, self.zdraw, _,
_) = delete_utils.batch_delete_reindex(self.cellpix, self.outpix,
self.ismanual, self.cellcolors,
self.zdraw, idx)
self.ncells -= len(idx) # _save_sets uses ncells
self.update_layer()

Expand Down
148 changes: 148 additions & 0 deletions tests/test_gui_delete_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
import importlib.util
from pathlib import Path


_DELETE_UTILS_PATH = (
Path(__file__).resolve().parents[1] / "cellpose/gui/delete_utils.py"
)
_DELETE_UTILS_SPEC = importlib.util.spec_from_file_location(
"cellpose_gui_delete_utils", _DELETE_UTILS_PATH
)
delete_utils = importlib.util.module_from_spec(_DELETE_UTILS_SPEC)
_DELETE_UTILS_SPEC.loader.exec_module(delete_utils)


def _legacy_remove_state(cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids):
cellpix = cellpix.copy()
outpix = outpix.copy()
ismanual = ismanual.copy()
cellcolors = cellcolors.copy()
zdraw = list(zdraw)

for idx in remove_ids:
cp = cellpix == idx
op = outpix == idx
cellpix[cp] = 0
outpix[op] = 0
cellpix[cellpix > idx] -= 1
outpix[outpix > idx] -= 1
ismanual = np.delete(ismanual, idx - 1)
cellcolors = np.delete(cellcolors, [idx], axis=0)
del zdraw[idx - 1]

return cellpix, outpix, ismanual, cellcolors, zdraw


def _random_state(seed, nz=1, ly=64, lx=64, ncells=40):
rng = np.random.default_rng(seed)
dtype = np.uint16 if ncells < 2**16 - 1 else np.uint32

cellpix = rng.integers(0, ncells + 1, size=(nz, ly, lx), dtype=dtype)
force_idx = rng.choice(cellpix.size, size=ncells, replace=False)
cellpix.flat[force_idx] = np.arange(1, ncells + 1, dtype=dtype)

outline_mask = rng.random(cellpix.shape) < 0.2
outpix = np.where(outline_mask, cellpix, 0).astype(dtype, copy=False)

ismanual = rng.integers(0, 2, size=ncells, dtype=np.uint8).astype(bool)
cellcolors = rng.integers(0, 256, size=(ncells + 1, 3), dtype=np.uint8)
cellcolors[0] = np.array([255, 255, 255], dtype=np.uint8)

zdraw = []
for _ in range(ncells):
nplanes = int(rng.integers(1, max(2, nz + 1)))
zdraw.append(list(rng.integers(0, max(1, nz), size=nplanes)))

return cellpix, outpix, ismanual, cellcolors, zdraw


def _assert_state_equal(expected, actual):
exp_cellpix, exp_outpix, exp_ismanual, exp_cellcolors, exp_zdraw = expected
got_cellpix, got_outpix, got_ismanual, got_cellcolors, got_zdraw = actual

assert np.array_equal(exp_cellpix, got_cellpix)
assert np.array_equal(exp_outpix, got_outpix)
assert np.array_equal(exp_ismanual, got_ismanual)
assert np.array_equal(exp_cellcolors, got_cellcolors)
assert len(exp_zdraw) == len(got_zdraw)
for z0, z1 in zip(exp_zdraw, got_zdraw):
assert np.array_equal(np.asarray(z0), np.asarray(z1))


def test_batch_delete_reindex_matches_legacy_small_example():
cellpix = np.array(
[[[1, 1, 2, 2], [1, 3, 3, 2], [4, 4, 5, 5], [4, 0, 5, 5]]], dtype=np.uint16
)
outpix = np.array(
[[[1, 0, 2, 0], [0, 3, 0, 2], [4, 0, 5, 0], [0, 0, 0, 5]]], dtype=np.uint16
)
ismanual = np.array([True, False, True, False, True])
cellcolors = np.array(
[[255, 255, 255], [10, 0, 0], [20, 0, 0], [30, 0, 0], [40, 0, 0], [50, 0, 0]],
dtype=np.uint8,
)
zdraw = [[0], [0], [0], [0], [0]]

remove_ids = np.array([5, 3, 2], dtype=np.int64)

expected = _legacy_remove_state(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)
got = delete_utils.batch_delete_reindex(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)[:5]
_assert_state_equal(expected, got)


def test_batch_delete_reindex_matches_legacy_random_2d():
for seed in range(20):
cellpix, outpix, ismanual, cellcolors, zdraw = _random_state(seed, nz=1)
rng = np.random.default_rng(seed + 1000)
ncells = len(cellcolors) - 1
remove_n = int(rng.integers(1, ncells + 1))
remove_ids = rng.choice(np.arange(1, ncells + 1), size=remove_n, replace=False)
remove_ids = delete_utils.normalize_remove_ids(remove_ids, ncells)

expected = _legacy_remove_state(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)
got = delete_utils.batch_delete_reindex(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)[:5]
_assert_state_equal(expected, got)


def test_batch_delete_reindex_matches_legacy_random_3d():
for seed in range(12):
cellpix, outpix, ismanual, cellcolors, zdraw = _random_state(seed + 100, nz=4)
rng = np.random.default_rng(seed + 2000)
ncells = len(cellcolors) - 1
remove_n = int(rng.integers(1, ncells + 1))
remove_ids = rng.choice(np.arange(1, ncells + 1), size=remove_n, replace=False)
remove_ids = delete_utils.normalize_remove_ids(remove_ids, ncells)

expected = _legacy_remove_state(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)
got = delete_utils.batch_delete_reindex(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)[:5]
_assert_state_equal(expected, got)


def test_batch_delete_reindex_noop_invalid_ids():
cellpix, outpix, ismanual, cellcolors, zdraw = _random_state(999, nz=1)
ncells = len(cellcolors) - 1
remove_ids = np.array([0, -1, ncells + 10], dtype=np.int64)

got = delete_utils.batch_delete_reindex(
cellpix, outpix, ismanual, cellcolors, zdraw, remove_ids
)
got_state = got[:5]
out_ids = got[5]
remove_mask = got[6]

_assert_state_equal((cellpix, outpix, ismanual, cellcolors, zdraw), got_state)
assert out_ids.size == 0
assert not remove_mask.any()