@@ -1204,9 +1204,9 @@ def suppfig_specialist(folder, save_fig=True):
12041204
12051205 il = 0
12061206
1207- fig = plt .figure (figsize = (9 , 5 ), dpi = 100 )
1207+ fig = plt .figure (figsize = (9 , 9 ), dpi = 100 )
12081208 yratio = 9 / 5
1209- grid = plt .GridSpec (2 , 4 , figure = fig , left = 0.02 , right = 0.96 , top = 0.96 , bottom = 0.1 ,
1209+ grid = plt .GridSpec (3 , 4 , figure = fig , left = 0.02 , right = 0.96 , top = 0.96 , bottom = 0.1 ,
12101210 wspace = 0.15 , hspace = 0.2 )
12111211
12121212 titles = ["train - clean" , "train - noisy" , "test - noisy" ]
@@ -1265,32 +1265,46 @@ def suppfig_specialist(folder, save_fig=True):
12651265 ax .set_xticks (np .arange (0.5 , 1.05 , 0.1 ))
12661266 ax .set_xlim ([0.5 , 1.0 ])
12671267
1268- transl = mtransforms .ScaledTranslation (- 10 / 72 , 20 / 72 , fig .dpi_scale_trans )
1268+ grid1 = matplotlib .gridspec .GridSpecFromSubplotSpec (2 , 5 , subplot_spec = grid [1 :, :], wspace = 0.05 ,
1269+ hspace = 0.1 )
12691270
1270- kk = [2 , 3 , 4 , 10 ]
1271+ transl = mtransforms .ScaledTranslation (- 10 / 72 , 25 / 72 , fig .dpi_scale_trans )
1272+
1273+ kk = [2 , 3 , 4 , 6 , 10 ]
12711274 iex = 8
1272- ylim = [10 , 310 ]
1273- xlim = [100 , 500 ]
1275+ ylim = [125 , 512 ] # [0, 350 ]
1276+ xlim = [50 , 325 ] # [ 100, 500]
12741277 legstr0 [- 1 ] = u"\u2013 Cellpose3 (per. + seg.)"
12751278 for j , k in enumerate (kk ):
1276- ax = plt .subplot (grid [1 , j ])
1277- pos = ax .get_position ().bounds
1278- ax .set_position ([pos [0 ], pos [1 ] - 0.07 , pos [2 ], pos [3 ]])
1279- img0 = imgs_all [k ][iex ].squeeze ()
1280- img0 *= 1.1
1281- img0 = np .clip (img0 , 0 , 1 )
1279+ outlines_gt = utils .outlines_list (masks_all [0 ][iex ].T .copy (), multiprocessing = False )
1280+ for ii in range (2 ):
1281+ ax = plt .subplot (grid1 [ii , j ])
1282+ pos = ax .get_position ().bounds
1283+ ax .set_position ([pos [0 ], pos [1 ] - 0.07 + ii * 0.03 , pos [2 ], pos [3 ]])
1284+ img0 = imgs_all [k ][iex ].squeeze ().T
1285+ masks0 = masks_all [k ][iex ].squeeze ().T
1286+ img0 *= 1.
1287+ img0 = np .clip (img0 , 0 , 1 )
12821288
1283- ax .imshow (img0 , cmap = "gray" , vmin = 0 , vmax = 1 )
1284- ax .axis ("off" )
1285- ax .set_ylim (ylim )
1286- ax .set_xlim (xlim )
1287- ax .set_title (legstr0 [k ][2 :], color = cols0 [k ], fontsize = "medium" )
1288- ax .text (1 , - 0.04 , f"AP@0.5 = { aps [k ,iex ,0 ] : 0.2f} " , va = "top" , ha = "right" ,
1289- transform = ax .transAxes )
1290- if j == 0 :
1291- il = plot_label (ltr , il , ax , transl , fs_title )
1292- ax .text (0.02 , 1.2 , "Denoised test image" , fontsize = "large" ,
1293- fontstyle = "italic" , transform = ax .transAxes )
1289+ ax .imshow (img0 , cmap = "gray" , vmin = 0 , vmax = 1 )
1290+ if ii == 1 :
1291+ outlines = utils .outlines_list (masks0 , multiprocessing = False )
1292+ for o in outlines_gt :
1293+ ax .plot (o [:, 0 ], o [:, 1 ], color = [0.7 ,0.4 ,1 ], lw = 2 )
1294+ for o in outlines :
1295+ ax .plot (o [:, 0 ], o [:, 1 ], color = [1 , 1 , 0.3 ], lw = 1.5 , ls = "--" )
1296+ ax .axis ("off" )
1297+ ax .set_ylim (ylim )
1298+ ax .set_xlim (xlim )
1299+ if ii == 0 :
1300+ ax .set_title (legstr0 [k ][2 :], color = cols0 [k ], fontsize = "medium" )
1301+ else :
1302+ ax .text (1 , - 0.04 , f"AP@0.5 = { aps [k ,iex ,0 ] : 0.2f} " , va = "top" , ha = "right" ,
1303+ transform = ax .transAxes )
1304+ if j == 0 and ii == 0 :
1305+ il = plot_label (ltr , il , ax , transl , fs_title )
1306+ ax .text (0.02 , 1.15 , "Denoised test image" , fontsize = "large" ,
1307+ fontstyle = "italic" , transform = ax .transAxes )
12941308
12951309 print (aps .mean (axis = 1 )[:, [0 , 5 , 8 ]])
12961310
@@ -1493,9 +1507,9 @@ def fig6(folder, save_fig=True):
14931507
14941508 diams = [utils .diameters (lbl )[0 ] for lbl in lbls ]
14951509
1496- gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
1510+ gen_model = "oneclick_cyto3" #" /home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
14971511 model = denoise .DenoiseModel (gpu = True , nchan = 1 , diam_mean = diam_mean ,
1498- pretrained_model = gen_model )
1512+ model_type = gen_model )
14991513 seg_model = models .CellposeModel (gpu = True , model_type = "cyto3" )
15001514 pscales = [1.5 , 20. , 1.5 , 1. , 5. , 40. , 3. ]
15011515 denoise .deterministic ()
@@ -1561,6 +1575,7 @@ def fig6(folder, save_fig=True):
15611575 legstr0 = ["" , u"\u2013 noisy image" , u"\u2013 original" ,
15621576 u"\u2013 noise-specific" , "\u2013 data-specific" , u"-- one-click" ]
15631577 theight = [0 , 0 ,4 ,3 ,2 ,1 ]
1578+ cstr = ["noisy\n image" , "blurry\n image" , "bilinear\n upsampled" ]
15641579 for i in range (6 ):
15651580 ctype = "cellpose test set" if i < 3 else "nuclei test set"
15661581 noise_type = ["denoising" , "deblurring" , "upsampling" ][i % 3 ]
@@ -1580,7 +1595,7 @@ def fig6(folder, save_fig=True):
15801595 if i == 1 or i == 4 :
15811596 ax .text (0.5 , 1.18 , ctype , transform = ax .transAxes , ha = "center" ,
15821597 fontsize = "large" )
1583-
1598+ ax . text ( 0.03 , 0.03 , cstr [ i % 3 ], transform = ax . transAxes , fontsize = "small" )
15841599 ax .set_ylim ([0 , 0.72 ])
15851600 ax .set_xticks (np .arange (0.5 , 1.05 , 0.25 ))
15861601 ax .set_xlim ([0.5 , 1.0 ])
@@ -1593,9 +1608,98 @@ def fig6(folder, save_fig=True):
15931608 ]
15941609 colsj = cols0 [[0 , 1 , - 1 ]]
15951610
1596- ly0 = 250
1611+ generalist_restoration_panels (fig , grid , imgs , lbls , masks , diams , api ,
1612+ titlesj , colsj , titlesi , j0 = 0 , il = il )
1613+
1614+ if save_fig :
1615+ os .makedirs ("figs/" , exist_ok = True )
1616+ fig .savefig ("figs/fig6.pdf" , dpi = 150 )
1617+
1618+ def suppfig_generalist_examples (folder , save_fig = True ):
1619+ cols0 = np .array ([[0 , 0 , 0 ], [0 , 0 , 0 ], [0 , 128 , 0 ], [180 , 229 , 162 ],
1620+ [246 , 198 , 173 ], [192 , 71 , 29 ], ])
1621+ cols0 = cols0 / 255
1622+ titlesi = [
1623+ "Tissuenet" , "Livecell" , "Yeaz bright-field" , "YeaZ phase-contrast" ,
1624+ "Omnipose phase-contrast" , "Omnipose fluorescent" , "DeepBacs"
1625+ ]
1626+ colsj = cols0 [[0 , 1 , - 1 ]]
1627+ folders = [
1628+ "cyto2" , "nuclei" , "tissuenet" , "livecell" , "yeast_BF" , "yeast_PhC" ,
1629+ "bact_phase" , "bact_fluor" , "deepbacs"
1630+ ]
1631+ diam_mean = 30.
1632+
1633+ #iexs = [340, 50, 10, 5, 70, 2, 33]
1634+ iexs = [305 , 1071 , 0 , 3 , 70 , 9 , 31 ]
1635+ imgs , lbls = [[], [], []], []
1636+ masks = [[], [], []]
1637+ for f , iex in zip (folders [2 :], iexs ):
1638+ dat = np .load (Path (folder ) / f"{ f } _generalist_masks.npy" ,
1639+ allow_pickle = True ).item ()
1640+ img = dat ["imgs" ][iex ].copy ()
1641+ img = img [:1 ] if img .ndim > 2 else img
1642+ img = np .maximum (0 , transforms .normalize99 (img ))
1643+ imgs [0 ].append (img )
1644+ masks [0 ].append (dat ["masks_pred" ][iex ])
1645+ lbls .append (dat ["masks" ][iex ].astype ("uint16" ))
1646+
1647+ diams = [utils .diameters (lbl )[0 ] for lbl in lbls ]
15971648
1598- transl = mtransforms .ScaledTranslation (- 15 / 72 , 30 / 72 , fig .dpi_scale_trans )
1649+ gen_model = "oneclick_cyto3"
1650+ model = denoise .DenoiseModel (gpu = True , nchan = 1 , diam_mean = diam_mean ,
1651+ model_type = gen_model )
1652+ seg_model = models .CellposeModel (gpu = True , model_type = "cyto3" )
1653+
1654+ fig = plt .figure (figsize = (14 , 8 ), dpi = 100 )
1655+ grid = plt .GridSpec (4 , 14 , figure = fig , left = 0.02 , right = 0.97 , top = 0.97 , bottom = 0.03 )
1656+
1657+ for ii in range (2 ):
1658+ if ii == 0 :
1659+ titlesj = ["clean" , "blurry" , "deblurred (one-click)" ]
1660+ else :
1661+ titlesj = ["clean" , "downsampled" , "upsampled (one-click)" ]
1662+ masks [1 ] = []
1663+ masks [2 ] = []
1664+ imgs [1 ] = []
1665+ imgs [2 ] = []
1666+ sigmas = [5. , 3. , 7. , 12. , 5. , 5. , 3. ]
1667+ ds = [6 ,4 ,8 ,8 ,6 ,6 ,6 ]
1668+ denoise .deterministic ()
1669+ for i , img in tqdm (enumerate (imgs [0 ])):
1670+ img0 = torch .from_numpy (img .copy ()).squeeze ().unsqueeze (0 ).unsqueeze (0 )
1671+ img0 = img0 .float ()
1672+ noisy0 = denoise .add_noise (img0 , poisson = 0. , downsample = 1. if ii == 1 else 0 ,
1673+ blur = 1. , ds = ds [i ] if ii == 1 else 0 ,
1674+ sigma0 = sigmas [i ] if ii == 0 else sigmas [i ]/ 2 ,
1675+ sigma1 = sigmas [i ] if ii == 0 else sigmas [i ]/ 2 ,
1676+ pscale = 120. ).numpy ().squeeze ()
1677+ denoised0 = model .eval (noisy0 , diameter = diams [i ], normalize = True )
1678+
1679+ imgs [1 ].append (noisy0 )
1680+ imgs [2 ].append (denoised0 )
1681+ for j in range (1 , 3 ):
1682+ masks [j ].append (
1683+ seg_model .eval (
1684+ imgs [j ][i ], diameter = diams [i ], channels = [0 , 0 ], tile_overlap = 0.5 ,
1685+ flow_threshold = 0.4 , augment = True , bsize = 224 ,
1686+ niter = 2000 if folders [i - 2 ] == "bact_phase" else None )[0 ])
1687+ api = np .array (
1688+ [metrics .average_precision (lbls , masks [i ])[0 ][:, 0 ] for i in range (3 )])
1689+
1690+ generalist_restoration_panels (fig , grid , imgs , lbls , masks , diams , api ,
1691+ titlesj , colsj , titlesi , j0 = - 1 + 2 * ii , letter = True )
1692+ if save_fig :
1693+ os .makedirs ("figs/" , exist_ok = True )
1694+ fig .savefig ("figs/suppfig_genex.pdf" , dpi = 150 )
1695+
1696+ def generalist_restoration_panels (fig , grid , imgs , lbls , masks , diams , api ,
1697+ titlesj , colsj , titlesi , j0 = 0 , ly0 = 250 , letter = False , il = 0 ):
1698+ if letter :
1699+ il = j0 > 0
1700+ transl = mtransforms .ScaledTranslation (- 20 / 72 , 15 / 72 , fig .dpi_scale_trans )
1701+ else :
1702+ transl = mtransforms .ScaledTranslation (- 20 / 72 , 5 / 72 , fig .dpi_scale_trans )
15991703 for i in range (len (imgs [0 ])):
16001704 ratio = diams [i ] / 30.
16011705 d = utils .diameters (lbls [i ])[0 ]
@@ -1608,20 +1712,18 @@ def fig6(folder, save_fig=True):
16081712 for j in range (1 , 3 ):
16091713 img = np .clip (transforms .normalize99 (imgs [j ][i ].copy ().squeeze ()), 0 , 1 )
16101714 for k in range (2 ):
1611- ax = plt .subplot (grid [j , 2 * i + k ])
1715+ ax = plt .subplot (grid [j + j0 , 2 * i + k ])
16121716 pos = ax .get_position ().bounds
16131717 ax .set_position ([
1614- pos [0 ] + 0.003 * i - 0.00 * k , pos [1 ] - (2 - j ) * 0.025 - 0.07 ,
1718+ pos [0 ] + 0.003 * i - 0.00 * k , pos [1 ] - (2 - j ) * 0.025 - 0.08 * ( j0 == 0 ) ,
16151719 pos [2 ], pos [3 ]
16161720 ])
16171721 if 1 :
16181722 ax .imshow (img , cmap = "gray" , vmin = 0 ,
1619- vmax = 0.35 if j == 1 and i == 2 else 1.0 )
1723+ vmax = 0.35 if j == 1 and i == 2 and j0 == 0 else 1.0 )
16201724 if k == 1 :
16211725 outlines = utils .outlines_list (masks [j ][i ],
16221726 multiprocessing = False )
1623- #for o in outlines_gt:
1624- # ax.plot(o[:,0], o[:,1], color=[0.7,0.4,1], lw=1, ls="-")
16251727 for o in outlines :
16261728 ax .plot (o [:, 0 ], o [:, 1 ], color = [1 , 1 , 0.3 ], lw = 1.5 ,
16271729 ls = "--" )
@@ -1638,17 +1740,19 @@ def fig6(folder, save_fig=True):
16381740 if k == 0 and i == 0 :
16391741 ax .text (- 0.22 , 0.5 , titlesj [j ], transform = ax .transAxes , va = "center" ,
16401742 rotation = 90 , color = colsj [j ], fontsize = "medium" )
1641- if j == 0 :
1743+ if j == 1 :
16421744 il = plot_label (ltr , il , ax , transl , fs_title )
1643- ax .text (- 0.0 , 1.22 , "Denoising examples from other datasets" ,
1745+ ax .text (- 0.02 , 1.05 , "Denoising examples from other datasets" ,
16441746 fontstyle = "italic" , transform = ax .transAxes ,
16451747 fontsize = "large" )
1646- if k == 0 and j == 0 :
1647- ax .text (0.0 , 1.05 , titlesi [i ], transform = ax .transAxes ,
1648- fontsize = "medium" )
1649- if save_fig :
1650- os .makedirs ("figs/" , exist_ok = True )
1651- fig .savefig ("figs/fig6.pdf" , dpi = 150 )
1748+ if j == 1 and letter :
1749+ ax .text (- 0.0 , 1.11 , "Deblurring examples from other datasets" if j0 == - 1 else "Upsampling examples from other datasets" ,
1750+ fontstyle = "italic" , transform = ax .transAxes ,
1751+ fontsize = "large" )
1752+ il = plot_label (ltr , il , ax , transl , fs_title )
1753+ #if k == 0 and (j == 0 or (j==1 and j0==0)):
1754+ #ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
1755+ # fontsize="medium")
16521756
16531757def load_seg_generalist (folder ):
16541758 folders = [
0 commit comments