Skip to content

Commit 537c65d

Browse files
adding other paper code
1 parent 9398cac commit 537c65d

File tree

4 files changed

+380
-42
lines changed

4 files changed

+380
-42
lines changed

cellpose/denoise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
MODEL_NAMES = []
2525
for ctype in ["cyto3", "cyto2", "nuclei"]:
26-
for ntype in ["denoise", "deblur", "upsample"]:
26+
for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
2727
MODEL_NAMES.append(f"{ntype}_{ctype}")
2828
if ctype != "cyto3":
2929
for ltype in ["per", "seg", "rec"]:

paper/3.0/analysis.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pathlib import Path
1010
import torch
1111
from torch import nn
12+
from tqdm import trange
1213

1314
# in same folder
1415
try:
@@ -299,6 +300,106 @@ def real_examples(folder):
299300
dat2["ap_n2v"] = ap_n2v
300301
np.save(root / "n2v_masks.npy", dat2)
301302

303+
def real_examples_ribo(root):
304+
navgs = [1, 2, 4, 8, 16, 32, 64]
305+
noisy = [[], [], [], [], [], [], []]
306+
clean = []
307+
for i in [1, 3, 6, 4, 5]:
308+
imgs = io.imread(Path(root) / f"denoise_{i:05d}_00001.tif")[:300]
309+
imgs = [imgs[:, :512, :512], imgs[:, 512:, :512], imgs[:, :512, 512:], imgs[:, 512:, 512:]]
310+
clean.extend([img.mean(axis=0) for img in imgs])
311+
for n, navg in enumerate(navgs):
312+
iavg = np.linspace(0, len(imgs[0])-1, navg+2).astype(int)[1:-1]
313+
noisy[n].extend(np.array([img[iavg].mean(axis=0) for img in imgs]))
314+
print(len(clean), len(noisy[0]))
315+
316+
thresholds = np.arange(0.5, 1.05, 0.05)
317+
diameter = 17
318+
normalize = True # {"tile_norm_blocksize": 80}
319+
seg_model = models.Cellpose(gpu=True, model_type="cyto2")
320+
model = denoise.DenoiseModel(gpu=True, model_type="denoise_cyto2")
321+
masks = seg_model.eval(clean, diameter=diameter, channels=[0,0],
322+
normalize=normalize)[0]
323+
ap_noisy = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
324+
ap_dn = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
325+
dat = {}
326+
dat["navgs"] = navgs
327+
dat["imgs_dn"] = []
328+
dat["masks_dn"] = []
329+
dat["masks_noisy"] = []
330+
dat["masks_clean"] = masks
331+
dat["noisy"] = noisy
332+
dat["clean"] = clean
333+
for n, imgs in enumerate(noisy):
334+
masks_noisy = seg_model.eval(imgs, diameter=diameter, channels=[0,0],
335+
normalize=normalize)[0]
336+
img_dn = model.eval(imgs, diameter=diameter, channels=[0,0],
337+
normalize=normalize)
338+
ap, tp, fp, fn = metrics.average_precision(masks, masks_noisy, threshold=thresholds)
339+
ap_noisy[n] = ap
340+
masks_dn = seg_model.eval(img_dn, diameter=diameter, channels=[0,0],
341+
normalize=normalize)[0]
342+
ap, tp, fp, fn = metrics.average_precision(masks, masks_dn, threshold=thresholds)
343+
ap_dn[n] = ap
344+
dat["imgs_dn"].append(img_dn)
345+
dat["masks_dn"].append(masks_dn)
346+
dat["masks_noisy"].append(masks_noisy)
347+
print(ap_noisy[n,:,0].mean(axis=0), ap_dn[n,:,0].mean(axis=0))
348+
dat["ap_noisy"] = ap_noisy
349+
dat["ap_dn"] = ap_dn
350+
np.save(Path(root) / "ribo_denoise.npy", dat)
351+
352+
dat = {}
353+
dat["navgs"] = navgs
354+
dat["imgs_n2s"] = []
355+
dat["masks_n2s"] = []
356+
dat["masks_clean"] = masks
357+
dat["noisy"] = noisy
358+
dat["clean"] = clean
359+
dat["ap_n2s"] = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
360+
361+
for n, imgs in enumerate(noisy):
362+
imgs_n2s = []
363+
for i in trange(len(imgs)):
364+
out = noise2self.train_per_image(imgs[i][np.newaxis,...].astype("float32"))
365+
imgs_n2s.append(out)
366+
imgs_n2s = np.array(imgs_n2s)
367+
masks_n2s = seg_model.eval(imgs_n2s, diameter=diameter, channels=[0,0])[0]
368+
ap, tp, fp, fn = metrics.average_precision(masks, masks_n2s, threshold=thresholds)
369+
dat["ap_n2s"][n] = ap
370+
dat["imgs_n2s"].append(imgs_n2s)
371+
dat["masks_n2s"].append(masks_n2s)
372+
print(n, ap.mean(axis=0)[[0, 5, 8]])
373+
374+
np.save(Path(root) / "ribo_denoise_n2s.npy", dat)
375+
376+
dat = {}
377+
dat["navgs"] = navgs
378+
dat["imgs_n2v"] = []
379+
dat["masks_n2v"] = []
380+
dat["masks_clean"] = masks
381+
dat["noisy"] = noisy
382+
dat["clean"] = clean
383+
dat["ap_n2v"] = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
384+
385+
for n, imgs in enumerate(noisy):
386+
imgs_n2v = []
387+
for i in trange(len(imgs)):
388+
out = noise2void.train_per_image(imgs[i].astype("float32"))
389+
imgs_n2v.append(out)
390+
imgs_n2v = np.array(imgs_n2v)
391+
masks_n2v = seg_model.eval(imgs_n2v, diameter=diameter, channels=[0,0],
392+
normalize=normalize)[0]
393+
ap, tp, fp, fn = metrics.average_precision(masks, masks_n2v, threshold=thresholds)
394+
#print(ap[:,0])
395+
dat["ap_n2v"][n] = ap
396+
dat["imgs_n2v"].append(imgs_n2v)
397+
dat["masks_n2v"].append(masks_n2v)
398+
print(n, ap.mean(axis=0)[[0, 5, 8]])
399+
400+
np.save(Path(root) / "ribo_denoise_n2v.npy", dat)
401+
402+
302403

303404
def specialist_training(root):
304405
""" root is path to specialist images (first 89 images of cyto2 and first 11 test images) """

paper/3.0/figures.py

Lines changed: 145 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,9 +1204,9 @@ def suppfig_specialist(folder, save_fig=True):
12041204

12051205
il = 0
12061206

1207-
fig = plt.figure(figsize=(9, 5), dpi=100)
1207+
fig = plt.figure(figsize=(9, 9), dpi=100)
12081208
yratio = 9 / 5
1209-
grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1,
1209+
grid = plt.GridSpec(3, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1,
12101210
wspace=0.15, hspace=0.2)
12111211

12121212
titles = ["train - clean", "train - noisy", "test - noisy"]
@@ -1265,32 +1265,46 @@ def suppfig_specialist(folder, save_fig=True):
12651265
ax.set_xticks(np.arange(0.5, 1.05, 0.1))
12661266
ax.set_xlim([0.5, 1.0])
12671267

1268-
transl = mtransforms.ScaledTranslation(-10 / 72, 20 / 72, fig.dpi_scale_trans)
1268+
grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[1:, :], wspace=0.05,
1269+
hspace=0.1)
12691270

1270-
kk = [2, 3, 4, 10]
1271+
transl = mtransforms.ScaledTranslation(-10 / 72, 25 / 72, fig.dpi_scale_trans)
1272+
1273+
kk = [2, 3, 4, 6, 10]
12711274
iex = 8
1272-
ylim = [10, 310]
1273-
xlim = [100, 500]
1275+
ylim = [125, 512] # [0, 350]
1276+
xlim = [50, 325] # [100, 500]
12741277
legstr0[-1] = u"\u2013 Cellpose3 (per. + seg.)"
12751278
for j, k in enumerate(kk):
1276-
ax = plt.subplot(grid[1, j])
1277-
pos = ax.get_position().bounds
1278-
ax.set_position([pos[0], pos[1] - 0.07, pos[2], pos[3]])
1279-
img0 = imgs_all[k][iex].squeeze()
1280-
img0 *= 1.1
1281-
img0 = np.clip(img0, 0, 1)
1279+
outlines_gt = utils.outlines_list(masks_all[0][iex].T.copy(), multiprocessing=False)
1280+
for ii in range(2):
1281+
ax = plt.subplot(grid1[ii, j])
1282+
pos = ax.get_position().bounds
1283+
ax.set_position([pos[0], pos[1] - 0.07 + ii*0.03, pos[2], pos[3]])
1284+
img0 = imgs_all[k][iex].squeeze().T
1285+
masks0 = masks_all[k][iex].squeeze().T
1286+
img0 *= 1.
1287+
img0 = np.clip(img0, 0, 1)
12821288

1283-
ax.imshow(img0, cmap="gray", vmin=0, vmax=1)
1284-
ax.axis("off")
1285-
ax.set_ylim(ylim)
1286-
ax.set_xlim(xlim)
1287-
ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium")
1288-
ax.text(1, -0.04, f"AP@0.5 = {aps[k,iex,0] : 0.2f}", va="top", ha="right",
1289-
transform=ax.transAxes)
1290-
if j == 0:
1291-
il = plot_label(ltr, il, ax, transl, fs_title)
1292-
ax.text(0.02, 1.2, "Denoised test image", fontsize="large",
1293-
fontstyle="italic", transform=ax.transAxes)
1289+
ax.imshow(img0, cmap="gray", vmin=0, vmax=1)
1290+
if ii==1:
1291+
outlines = utils.outlines_list(masks0, multiprocessing=False)
1292+
for o in outlines_gt:
1293+
ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2)
1294+
for o in outlines:
1295+
ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--")
1296+
ax.axis("off")
1297+
ax.set_ylim(ylim)
1298+
ax.set_xlim(xlim)
1299+
if ii==0:
1300+
ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium")
1301+
else:
1302+
ax.text(1, -0.04, f"AP@0.5 = {aps[k,iex,0] : 0.2f}", va="top", ha="right",
1303+
transform=ax.transAxes)
1304+
if j == 0 and ii==0:
1305+
il = plot_label(ltr, il, ax, transl, fs_title)
1306+
ax.text(0.02, 1.15, "Denoised test image", fontsize="large",
1307+
fontstyle="italic", transform=ax.transAxes)
12941308

12951309
print(aps.mean(axis=1)[:, [0, 5, 8]])
12961310

@@ -1493,9 +1507,9 @@ def fig6(folder, save_fig=True):
14931507

14941508
diams = [utils.diameters(lbl)[0] for lbl in lbls]
14951509

1496-
gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
1510+
gen_model = "oneclick_cyto3" #"/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
14971511
model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
1498-
pretrained_model=gen_model)
1512+
model_type=gen_model)
14991513
seg_model = models.CellposeModel(gpu=True, model_type="cyto3")
15001514
pscales = [1.5, 20., 1.5, 1., 5., 40., 3.]
15011515
denoise.deterministic()
@@ -1561,6 +1575,7 @@ def fig6(folder, save_fig=True):
15611575
legstr0 = ["", u"\u2013 noisy image", u"\u2013 original",
15621576
u"\u2013 noise-specific", "\u2013 data-specific", u"-- one-click"]
15631577
theight = [0, 0,4,3,2,1]
1578+
cstr = ["noisy\nimage", "blurry\nimage", "bilinear\nupsampled"]
15641579
for i in range(6):
15651580
ctype = "cellpose test set" if i < 3 else "nuclei test set"
15661581
noise_type = ["denoising", "deblurring", "upsampling"][i % 3]
@@ -1580,7 +1595,7 @@ def fig6(folder, save_fig=True):
15801595
if i == 1 or i == 4:
15811596
ax.text(0.5, 1.18, ctype, transform=ax.transAxes, ha="center",
15821597
fontsize="large")
1583-
1598+
ax.text(0.03, 0.03, cstr[i%3], transform=ax.transAxes, fontsize="small")
15841599
ax.set_ylim([0, 0.72])
15851600
ax.set_xticks(np.arange(0.5, 1.05, 0.25))
15861601
ax.set_xlim([0.5, 1.0])
@@ -1593,9 +1608,98 @@ def fig6(folder, save_fig=True):
15931608
]
15941609
colsj = cols0[[0, 1, -1]]
15951610

1596-
ly0 = 250
1611+
generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
1612+
titlesj, colsj, titlesi, j0=0, il=il)
1613+
1614+
if save_fig:
1615+
os.makedirs("figs/", exist_ok=True)
1616+
fig.savefig("figs/fig6.pdf", dpi=150)
1617+
1618+
def suppfig_generalist_examples(folder, save_fig=True):
1619+
cols0 = np.array([[0, 0, 0], [0, 0, 0], [0, 128, 0], [180, 229, 162],
1620+
[246, 198, 173], [192, 71, 29], ])
1621+
cols0 = cols0 / 255
1622+
titlesi = [
1623+
"Tissuenet", "Livecell", "Yeaz bright-field", "YeaZ phase-contrast",
1624+
"Omnipose phase-contrast", "Omnipose fluorescent", "DeepBacs"
1625+
]
1626+
colsj = cols0[[0, 1, -1]]
1627+
folders = [
1628+
"cyto2", "nuclei", "tissuenet", "livecell", "yeast_BF", "yeast_PhC",
1629+
"bact_phase", "bact_fluor", "deepbacs"
1630+
]
1631+
diam_mean = 30.
1632+
1633+
#iexs = [340, 50, 10, 5, 70, 2, 33]
1634+
iexs = [305, 1071, 0, 3, 70, 9, 31]
1635+
imgs, lbls = [[], [], []], []
1636+
masks = [[], [], []]
1637+
for f, iex in zip(folders[2:], iexs):
1638+
dat = np.load(Path(folder) / f"{f}_generalist_masks.npy",
1639+
allow_pickle=True).item()
1640+
img = dat["imgs"][iex].copy()
1641+
img = img[:1] if img.ndim > 2 else img
1642+
img = np.maximum(0, transforms.normalize99(img))
1643+
imgs[0].append(img)
1644+
masks[0].append(dat["masks_pred"][iex])
1645+
lbls.append(dat["masks"][iex].astype("uint16"))
1646+
1647+
diams = [utils.diameters(lbl)[0] for lbl in lbls]
15971648

1598-
transl = mtransforms.ScaledTranslation(-15 / 72, 30 / 72, fig.dpi_scale_trans)
1649+
gen_model = "oneclick_cyto3"
1650+
model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
1651+
model_type=gen_model)
1652+
seg_model = models.CellposeModel(gpu=True, model_type="cyto3")
1653+
1654+
fig = plt.figure(figsize=(14, 8), dpi=100)
1655+
grid = plt.GridSpec(4, 14, figure=fig, left=0.02, right=0.97, top=0.97, bottom=0.03)
1656+
1657+
for ii in range(2):
1658+
if ii==0:
1659+
titlesj = ["clean", "blurry", "deblurred (one-click)"]
1660+
else:
1661+
titlesj = ["clean", "downsampled", "upsampled (one-click)"]
1662+
masks[1] = []
1663+
masks[2] = []
1664+
imgs[1] = []
1665+
imgs[2] = []
1666+
sigmas = [5., 3., 7., 12., 5., 5., 3.]
1667+
ds = [6,4,8,8,6,6,6]
1668+
denoise.deterministic()
1669+
for i, img in tqdm(enumerate(imgs[0])):
1670+
img0 = torch.from_numpy(img.copy()).squeeze().unsqueeze(0).unsqueeze(0)
1671+
img0 = img0.float()
1672+
noisy0 = denoise.add_noise(img0, poisson=0., downsample=1. if ii==1 else 0,
1673+
blur=1., ds=ds[i] if ii==1 else 0,
1674+
sigma0 = sigmas[i] if ii==0 else sigmas[i]/2,
1675+
sigma1 = sigmas[i] if ii==0 else sigmas[i]/2,
1676+
pscale=120.).numpy().squeeze()
1677+
denoised0 = model.eval(noisy0, diameter=diams[i], normalize=True)
1678+
1679+
imgs[1].append(noisy0)
1680+
imgs[2].append(denoised0)
1681+
for j in range(1, 3):
1682+
masks[j].append(
1683+
seg_model.eval(
1684+
imgs[j][i], diameter=diams[i], channels=[0, 0], tile_overlap=0.5,
1685+
flow_threshold=0.4, augment=True, bsize=224,
1686+
niter=2000 if folders[i - 2] == "bact_phase" else None)[0])
1687+
api = np.array(
1688+
[metrics.average_precision(lbls, masks[i])[0][:, 0] for i in range(3)])
1689+
1690+
generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
1691+
titlesj, colsj, titlesi, j0=-1 + 2*ii, letter=True)
1692+
if save_fig:
1693+
os.makedirs("figs/", exist_ok=True)
1694+
fig.savefig("figs/suppfig_genex.pdf", dpi=150)
1695+
1696+
def generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
1697+
titlesj, colsj, titlesi, j0=0, ly0=250, letter=False, il=0):
1698+
if letter:
1699+
il = j0>0
1700+
transl = mtransforms.ScaledTranslation(-20 / 72, 15 / 72, fig.dpi_scale_trans)
1701+
else:
1702+
transl = mtransforms.ScaledTranslation(-20 / 72, 5 / 72, fig.dpi_scale_trans)
15991703
for i in range(len(imgs[0])):
16001704
ratio = diams[i] / 30.
16011705
d = utils.diameters(lbls[i])[0]
@@ -1608,20 +1712,18 @@ def fig6(folder, save_fig=True):
16081712
for j in range(1, 3):
16091713
img = np.clip(transforms.normalize99(imgs[j][i].copy().squeeze()), 0, 1)
16101714
for k in range(2):
1611-
ax = plt.subplot(grid[j, 2 * i + k])
1715+
ax = plt.subplot(grid[j+j0, 2 * i + k])
16121716
pos = ax.get_position().bounds
16131717
ax.set_position([
1614-
pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.07,
1718+
pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.08*(j0==0),
16151719
pos[2], pos[3]
16161720
])
16171721
if 1:
16181722
ax.imshow(img, cmap="gray", vmin=0,
1619-
vmax=0.35 if j == 1 and i == 2 else 1.0)
1723+
vmax=0.35 if j == 1 and i == 2 and j0==0 else 1.0)
16201724
if k == 1:
16211725
outlines = utils.outlines_list(masks[j][i],
16221726
multiprocessing=False)
1623-
#for o in outlines_gt:
1624-
# ax.plot(o[:,0], o[:,1], color=[0.7,0.4,1], lw=1, ls="-")
16251727
for o in outlines:
16261728
ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5,
16271729
ls="--")
@@ -1638,17 +1740,19 @@ def fig6(folder, save_fig=True):
16381740
if k == 0 and i == 0:
16391741
ax.text(-0.22, 0.5, titlesj[j], transform=ax.transAxes, va="center",
16401742
rotation=90, color=colsj[j], fontsize="medium")
1641-
if j == 0:
1743+
if j==1:
16421744
il = plot_label(ltr, il, ax, transl, fs_title)
1643-
ax.text(-0.0, 1.22, "Denoising examples from other datasets",
1745+
ax.text(-0.02, 1.05, "Denoising examples from other datasets",
16441746
fontstyle="italic", transform=ax.transAxes,
16451747
fontsize="large")
1646-
if k == 0 and j == 0:
1647-
ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
1648-
fontsize="medium")
1649-
if save_fig:
1650-
os.makedirs("figs/", exist_ok=True)
1651-
fig.savefig("figs/fig6.pdf", dpi=150)
1748+
if j==1 and letter:
1749+
ax.text(-0.0, 1.11, "Deblurring examples from other datasets" if j0==-1 else "Upsampling examples from other datasets",
1750+
fontstyle="italic", transform=ax.transAxes,
1751+
fontsize="large")
1752+
il = plot_label(ltr, il, ax, transl, fs_title)
1753+
#if k == 0 and (j == 0 or (j==1 and j0==0)):
1754+
#ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
1755+
# fontsize="medium")
16521756

16531757
def load_seg_generalist(folder):
16541758
folders = [

0 commit comments

Comments
 (0)