Skip to content

Commit ede96b9

Browse files
Merge pull request #65 from computational-cell-analytics/figure_update
Add crops of baseline segmentation
2 parents 25f7d73 + ef6761c commit ede96b9

File tree

6 files changed

+167
-61
lines changed

6 files changed

+167
-61
lines changed

scripts/figures/plot_fig2.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,116 @@
44
import numpy as np
55
import pandas as pd
66
import matplotlib.pyplot as plt
7+
import tifffile
8+
from matplotlib import colors
9+
from skimage.segmentation import find_boundaries
710

811
from util import literature_reference_values
912

1013
png_dpi = 300
1114

1215

16+
def scramble_instance_labels(arr):
17+
"""Scramble indexes of instance segmentation to avoid neighboring colors.
18+
"""
19+
unique = list(np.unique(arr)[1:])
20+
rng = np.random.default_rng(seed=42)
21+
new_list = rng.uniform(1, len(unique) + 1, size=(len(unique)))
22+
new_arr = np.zeros(arr.shape)
23+
for old_id, new_id in zip(unique, new_list):
24+
new_arr[arr == old_id] = new_id
25+
return new_arr
26+
27+
28+
def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba=[0, 0, 0, 0.5], plot=False):
29+
seg = tifffile.imread(seg_path)
30+
if len(seg.shape) == 3:
31+
seg = seg[10, xlim1:xlim2, ylim1:ylim2]
32+
else:
33+
seg = seg[xlim1:xlim2, ylim1:ylim2]
34+
35+
img = tifffile.imread(img_path)
36+
img = img[10, xlim1:xlim2, ylim1:ylim2]
37+
38+
# create color map with random distribution for coloring instance segmentation
39+
unique = list(np.unique(seg)[1:])
40+
n_instances = len(unique)
41+
42+
seg = scramble_instance_labels(seg)
43+
44+
rng = np.random.default_rng(seed=42) # fixed seed for reproducibility
45+
colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1]
46+
colors_array[:, 3] = 1.0 # full alpha
47+
colors_array[0, 3] = 0.0 # make label 0 transparent (background)
48+
cmap = colors.ListedColormap(colors_array)
49+
50+
boundaries = find_boundaries(seg, mode="inner")
51+
boundary_overlay = np.zeros((*boundaries.shape, 4))
52+
53+
boundary_overlay[boundaries] = boundary_rgba # RGBA = black
54+
55+
fig, ax = plt.subplots(figsize=(6, 6))
56+
ax.imshow(img, cmap="gray")
57+
ax.imshow(seg, cmap=cmap, alpha=0.5, interpolation="nearest")
58+
ax.imshow(boundary_overlay)
59+
ax.axis("off")
60+
plt.tight_layout()
61+
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
62+
63+
if plot:
64+
plt.show()
65+
else:
66+
plt.close()
67+
68+
69+
def fig_02b_sgn(save_dir, plot=False):
70+
"""Plot crops of SGN segmentation of CochleaNet, Cellpose and micro-sam.
71+
"""
72+
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
73+
val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn"
74+
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation"
75+
76+
crop_name = "MLR169R_PV_z3420_allturns_full"
77+
img_path = os.path.join(image_dir, f"{crop_name}.tif")
78+
79+
xlim1 = 2000
80+
xlim2 = 2500
81+
ylim1 = 3100
82+
ylim2 = 3600
83+
boundary_rgba = [1, 1, 1, 0.5]
84+
85+
for seg_net in ["distance_unet", "cellpose-sam", "micro-sam"]:
86+
save_path = os.path.join(save_dir, f"fig_02b_sgn_{seg_net}.png")
87+
seg_dir = os.path.join(val_sgn_dir, seg_net)
88+
seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")
89+
90+
plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)
91+
92+
93+
def fig_02b_ihc(save_dir, plot=False):
94+
"""Plot crops of IHC segmentation of CochleaNet, Cellpose and micro-sam.
95+
"""
96+
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
97+
val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc"
98+
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs"
99+
100+
crop_name = "MLR226L_VGlut3_z1200_3turns_full"
101+
img_path = os.path.join(image_dir, f"{crop_name}.tif")
102+
103+
xlim1 = 1900
104+
xlim2 = 2400
105+
ylim1 = 2000
106+
ylim2 = 2500
107+
boundary_rgba = [1, 1, 1, 0.5]
108+
109+
for seg_net in ["distance_unet_v4b", "cellpose-sam", "micro-sam"]:
110+
save_path = os.path.join(save_dir, f"fig_02b_ihc_{seg_net}.png")
111+
seg_dir = os.path.join(val_sgn_dir, seg_net)
112+
seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")
113+
114+
plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)
115+
116+
13117
def fig_02c(save_path, plot=False, all_versions=False):
14118
"""Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
15119
IHC (distance U-Net, manual), and synapse detection (U-Net).
@@ -299,6 +403,8 @@ def main():
299403
os.makedirs(args.figure_dir, exist_ok=True)
300404

301405
# Panel C: Evaluation of the segmentation results:
406+
fig_02b_sgn(save_dir=args.figure_dir, plot=args.plot)
407+
fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot)
302408
fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot, all_versions=False)
303409

304410
# Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC

scripts/figures/plot_fig3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import os
33
import imageio.v3 as imageio
4-
from glob import glob
54

65
import matplotlib.pyplot as plt
76
import numpy as np
@@ -72,7 +71,10 @@ def fig_03b(save_path):
7271

7372

7473
def fig_03c_rl(save_path, plot=False):
75-
tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
74+
ihc_version = "ihc_counts_v4c"
75+
synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
76+
tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name]
77+
7678
fig, ax = plt.subplots(figsize=(8, 4))
7779

7880
width = 50 # micron
@@ -102,7 +104,9 @@ def fig_03c_rl(save_path, plot=False):
102104

103105

104106
def fig_03c_octave(save_path, plot=False):
105-
tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
107+
ihc_version = "ihc_counts_v4c"
108+
synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
109+
tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name]
106110

107111
result = {"cochlea": [], "octave_band": [], "value": []}
108112
for tab_path in tables:
@@ -134,7 +138,7 @@ def fig_03c_octave(save_path, plot=False):
134138

135139
ax.set_ylabel("Average Ribbon Synapse Count per IHC")
136140
ax.set_title("Ribbon synapse count per octave band")
137-
ax.legend(title="Cochlea")
141+
plt.legend(title="Cochlea")
138142

139143
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
140144
if plot:

scripts/figures/plot_fig4.py

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,21 @@
1313
INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/measurements2" # noqa
1414

1515
# The cochlea for the CHReef analysis.
16-
COCHLEAE = [
17-
"M_LR_000143_L",
18-
"M_LR_000144_L",
19-
"M_LR_000145_L",
20-
"M_LR_000153_L",
21-
"M_LR_000155_L",
22-
"M_LR_000189_L",
23-
"M_LR_000143_R",
24-
"M_LR_000144_R",
25-
"M_LR_000145_R",
26-
"M_LR_000153_R",
27-
"M_LR_000155_R",
28-
"M_LR_000189_R",
29-
]
30-
31-
COCHLEAE_GERBIL = [
32-
"G_EK_000049_L",
33-
"G_EK_000049_R",
34-
]
35-
36-
37-
COCHLEAE_ALIAS = {
38-
"M_LR_000143_L": "M0L",
39-
"M_LR_000144_L": "M05L",
40-
"M_LR_000145_L": "M06L",
41-
"M_LR_000153_L": "M07L",
42-
"M_LR_000155_L": "M08L",
43-
"M_LR_000189_L": "M09L",
44-
"M_LR_000143_R": "M0R",
45-
"M_LR_000144_R": "M05R",
46-
"M_LR_000145_R": "M06R",
47-
"M_LR_000153_R": "M07R",
48-
"M_LR_000155_R": "M08R",
49-
"M_LR_000189_R": "M09R",
50-
"G_EK_000049_L": "G1L",
51-
"G_EK_000049_R": "G1R",
16+
COCHLEAE_DICT = {
17+
"M_LR_000143_L": {"alias": "M0L", "component": [1]},
18+
"M_LR_000144_L": {"alias": "M05L", "component": [1]},
19+
"M_LR_000145_L": {"alias": "M06L", "component": [1]},
20+
"M_LR_000153_L": {"alias": "M07L", "component": [1]},
21+
"M_LR_000155_L": {"alias": "M08L", "component": [1, 2, 3]},
22+
"M_LR_000189_L": {"alias": "M09L", "component": [1]},
23+
"M_LR_000143_R": {"alias": "M0R", "component": [1]},
24+
"M_LR_000144_R": {"alias": "M05R", "component": [1]},
25+
"M_LR_000145_R": {"alias": "M06R", "component": [1]},
26+
"M_LR_000153_R": {"alias": "M07R", "component": [1]},
27+
"M_LR_000155_R": {"alias": "M08R", "component": [1]},
28+
"M_LR_000189_R": {"alias": "M09R", "component": [1]},
29+
"G_EK_000049_L": {"alias": "G1L", "component": [1, 3, 4, 5]},
30+
"G_EK_000049_R": {"alias": "G1R", "component": [1, 2]},
5231
}
5332

5433
png_dpi = 300
@@ -60,10 +39,10 @@ def get_chreef_data(animal="mouse"):
6039

6140
if animal == "mouse":
6241
cache_path = "./chreef_data.pkl"
63-
cochleae = COCHLEAE
42+
cochleae = [key for key in COCHLEAE_DICT.keys() if "M_" in key]
6443
else:
6544
cache_path = "./chreef_data_gerbil.pkl"
66-
cochleae = COCHLEAE_GERBIL
45+
cochleae = [key for key in COCHLEAE_DICT.keys() if "G_" in key]
6746

6847
if os.path.exists(cache_path):
6948
with open(cache_path, "rb") as f:
@@ -83,7 +62,9 @@ def get_chreef_data(animal="mouse"):
8362
table = pd.read_csv(table_content, sep="\t")
8463

8564
# May need to be adjusted for some cochleae.
86-
table = table[table.component_labels == 1]
65+
component_labels = COCHLEAE_DICT[cochlea]["component"]
66+
print(cochlea, component_labels)
67+
table = table[table.component_labels.isin(component_labels)]
8768
# The relevant values for analysis.
8869
try:
8970
values = table[["label_id", "length[µm]", "frequency[kHz]", "marker_labels"]]
@@ -136,7 +117,7 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr
136117

137118
# TODO have central function for alias for all plots?
138119
if use_alias:
139-
alias = [COCHLEAE_ALIAS[k] for k in chreef_data.keys()]
120+
alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()]
140121
else:
141122
alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()]
142123

@@ -206,7 +187,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa
206187
"""Transduction efficiency per cochlea.
207188
"""
208189
if use_alias:
209-
alias = [COCHLEAE_ALIAS[k] for k in chreef_data.keys()]
190+
alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()]
210191
else:
211192
alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()]
212193

@@ -237,7 +218,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa
237218
main_label_size = 20
238219
sub_label_size = 16
239220
main_tick_size = 12
240-
legendsize = 16
221+
legendsize = 16 if intensity else 12
241222

242223
label = "Intensity" if intensity else "Transduction efficiency"
243224
if plot_by_side:
@@ -274,7 +255,7 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
274255
result = {"cochlea": [], "octave_band": [], "value": []}
275256
for name, values in chreef_data.items():
276257
if use_alias:
277-
alias = COCHLEAE_ALIAS[name]
258+
alias = COCHLEAE_DICT[name]["alias"]
278259
else:
279260
alias = name.replace("_", "").replace("0", "")
280261

scripts/figures/plot_fig6.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import matplotlib.pyplot as plt
55

6+
from util import literature_reference_values_gerbil
7+
68
png_dpi = 300
79

810

@@ -27,9 +29,10 @@ def fig_06a(save_path, plot=False):
2729
# Labels and formatting
2830
ax[0].set_xticklabels(["SGN"], fontsize=main_label_size)
2931

30-
ylim0 = 12000
31-
ylim1 = 22500
32-
y_ticks = [i for i in range(ylim0, ylim1 + 1, 2000)]
32+
ylim0 = 14000
33+
ylim1 = 30000
34+
ytick_gap = 4000
35+
y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)]
3336

3437
ax[0].set_ylabel('Count per cochlea', fontsize=main_label_size)
3538
ax[0].set_yticks(y_ticks)
@@ -40,19 +43,18 @@ def fig_06a(save_path, plot=False):
4043
xmin = 0.5
4144
xmax = 1.5
4245
ax[0].set_xlim(xmin, xmax)
43-
upper_y = 15000
44-
lower_y = 13000
46+
lower_y, upper_y = literature_reference_values_gerbil("SGN")
4547
ax[0].hlines([lower_y, upper_y], xmin, xmax)
46-
ax[0].text(1, upper_y + 100, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center")
48+
ax[0].text(1, upper_y + 100, "literature reference", color='C0', fontsize=main_tick_size, ha="center")
4749
ax[0].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True)
4850

49-
ylim0 = 800
51+
ylim0 = 900
5052
ylim1 = 1400
51-
y_ticks = [i for i in range(ylim0, ylim1 + 1, 100)]
53+
ytick_gap = 200
54+
y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)]
5255

5356
ax[1].set_xticklabels(["IHC"], fontsize=main_label_size)
5457

55-
ax[1].set_ylabel('Count per cochlea', fontsize=main_label_size)
5658
ax[1].set_yticks(y_ticks)
5759
ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
5860
ax[1].set_ylim(ylim0, ylim1)
@@ -61,10 +63,9 @@ def fig_06a(save_path, plot=False):
6163
xmin = 0.5
6264
xmax = 1.5
6365
ax[1].set_xlim(xmin, xmax)
64-
upper_y = 1200
65-
lower_y = 1000
66+
lower_y, upper_y = literature_reference_values_gerbil("IHC")
6667
ax[1].hlines([lower_y, upper_y], xmin, xmax)
67-
ax[1].text(1, upper_y + 10, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center")
68+
ax[1].text(1, upper_y + 10, "literature reference", color='C0', fontsize=main_tick_size, ha="center")
6869
ax[1].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True)
6970

7071
plt.tight_layout()

scripts/figures/util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,15 @@ def literature_reference_values(structure):
7070
else:
7171
raise ValueError
7272
return lower_bound, upper_bound
73+
74+
75+
def literature_reference_values_gerbil(structure):
76+
if structure == "SGN":
77+
lower_bound, upper_bound = 24700, 28450
78+
elif structure == "IHC":
79+
lower_bound, upper_bound = 1081, 1081
80+
elif structure == "synapse":
81+
lower_bound, upper_bound = 9.1, 20.7
82+
else:
83+
raise ValueError
84+
return lower_bound, upper_bound

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def main():
3737
in which case the boundary distances are not used for the seeds.")
3838
parser.add_argument("--fg_threshold", default=0.5, type=float,
3939
help="The threshold applied to the foreground prediction for deriving the watershed mask.")
40+
parser.add_argument("--distance_smoothing", default=0, type=float,
41+
help="The sigma value for smoothing the distance predictions with a gaussian kernel.")
4042

4143
args = parser.parse_args()
4244

@@ -78,7 +80,7 @@ def main():
7880
seg_class=args.seg_class,
7981
center_distance_threshold=args.center_distance_threshold,
8082
boundary_distance_threshold=args.boundary_distance_threshold,
81-
fg_threshold=args.fg_threshold,
83+
fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing,
8284
)
8385

8486
abs_path = os.path.abspath(args.input)
@@ -95,7 +97,7 @@ def main():
9597
seg_class=args.seg_class,
9698
center_distance_threshold=args.center_distance_threshold,
9799
boundary_distance_threshold=args.boundary_distance_threshold,
98-
fg_threshold=args.fg_threshold,
100+
fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing,
99101
)
100102
timer_output = os.path.join(args.output_folder, "timer.json")
101103

0 commit comments

Comments
 (0)