Skip to content

Commit c110dfe

Browse files
committed
Initial supplementary figure 02
1 parent e5c1d1a commit c110dfe

File tree

1 file changed

+134
-50
lines changed

1 file changed

+134
-50
lines changed

scripts/figures/plot_fig2.py

Lines changed: 134 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
COLOR_P = "#9C5027"
2020
COLOR_R = "#67279C"
2121
COLOR_F = "#9C276F"
22+
COLOR_T = "#279C52"
2223

2324

2425
def scramble_instance_labels(arr):
@@ -125,97 +126,176 @@ def fig_02b_ihc(save_dir, plot=False):
125126
plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)
126127

127128

128-
def supp_fig_02(save_path, plot=False, segm="SGN"):
129+
def plot_legend_suppfig02(figure_dir):
130+
"""Plot common legend for figure 2c.
131+
132+
Args:
133+
chreef_data: Data of ChReef cochleae.
134+
save_path: save path to save legend.
135+
grouping: Grouping for cochleae.
136+
"side_mono" for division in Injected and Non-Injected.
137+
"side_multi" for division per cochlea.
138+
"animal" for division per animal.
139+
use_alias: Use alias.
140+
"""
141+
save_path_colors = os.path.join(figure_dir, f"suppfig02_legend_colors.{FILE_EXTENSION}")
142+
# Colors
143+
color = [COLOR_P, COLOR_R, COLOR_F, COLOR_T]
144+
label = ["Precision", "Recall", "F1-score", "Processing time"]
145+
146+
fl = lambda c: Line2D([], [], lw=3, color=c)
147+
handles = [fl(c) for c in color]
148+
legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False)
149+
export_legend(legend, save_path_colors)
150+
legend.remove()
151+
plt.close()
152+
153+
154+
def supp_fig_02(save_path, plot=False, segm="SGN", mode="precision"):
129155
# SGN
130156
value_dict = {
131157
"SGN": {
132-
"distance_unet": {
133-
"label": "CochleaNet",
134-
"precision": 0.886,
135-
"recall": 0.804,
136-
"f1-score": 0.837
158+
"stardist": {
159+
"label": "Stardist",
160+
"precision": 0.706,
161+
"recall": 0.630,
162+
"f1-score": 0.628,
163+
"marker": "o",
164+
"runtime": 536.5,
165+
"runtime_std": 148.4
166+
137167
},
138168
"micro_sam": {
139169
"label": "µSAM",
140170
"precision": 0.140,
141171
"recall": 0.782,
142-
"f1-score": 0.228
143-
},
144-
"cellpose_sam": {
145-
"label": "Cellpose-SAM",
146-
"precision": 0.250,
147-
"recall": 0.003,
148-
"f1-score": 0.005
172+
"f1-score": 0.228,
173+
"marker": "D",
174+
"runtime": 407.5,
175+
"runtime_std": 107.5
149176
},
150177
"cellpose_3": {
151178
"label": "Cellpose 3",
152179
"precision": 0.117,
153180
"recall": 0.607,
154-
"f1-score": 0.186
181+
"f1-score": 0.186,
182+
"marker": "v",
183+
"runtime": None,
184+
"runtime_std": None
155185
},
156-
"stardist": {
157-
"label": "Stardist",
158-
"precision": 0.706,
159-
"recall": 0.630,
160-
"f1-score": 0.628
186+
"cellpose_sam": {
187+
"label": "Cellpose-SAM",
188+
"precision": 0.250,
189+
"recall": 0.003,
190+
"f1-score": 0.005,
191+
"marker": "^",
192+
"runtime": 167.9,
193+
"runtime_std": 40.2
161194
},
162-
},
163-
"IHC": {
164195
"distance_unet": {
165196
"label": "CochleaNet",
166-
"precision": 0.664,
167-
"recall": 0.661,
168-
"f1-score": 0.659
197+
"precision": 0.886,
198+
"recall": 0.804,
199+
"f1-score": 0.837,
200+
"marker": "s",
201+
"runtime": 168.8,
202+
"runtime_std": 21.8
169203
},
204+
},
205+
"IHC": {
170206
"micro_sam": {
171207
"label": "µSAM",
172208
"precision": 0.053,
173209
"recall": 0.684,
174-
"f1-score": 0.094
210+
"f1-score": 0.094,
211+
"marker": "D",
212+
"runtime": 445.6,
213+
"runtime_std": 106.6
214+
},
215+
"cellpose_3": {
216+
"label": "Cellpose 3",
217+
"precision": 0.375,
218+
"recall": 0.554,
219+
"f1-score": 0.329,
220+
"marker": "v",
221+
"runtime": 30.1,
222+
"runtime_std": 162.3
175223
},
176224
"cellpose_sam": {
177225
"label": "Cellpose-SAM",
178226
"precision": 0.636,
179227
"recall": 0.025,
180-
"f1-score": 0.047
228+
"f1-score": 0.047,
229+
"marker": "^",
230+
"runtime": None,
231+
"runtime_std": None
181232
},
182-
"cellpose_3": {
183-
"label": "Cellpose 3",
184-
"precision": 0.375,
185-
"recall": 0.554,
186-
"f1-score": 0.329
233+
"distance_unet": {
234+
"label": "CochleaNet",
235+
"precision": 0.664,
236+
"recall": 0.661,
237+
"f1-score": 0.659,
238+
"marker": "s",
239+
"runtime": 65.7,
240+
"runtime_std": 72.6
187241
},
188242
}
189243
}
190244

191-
# SGN
192-
precision = [value_dict[segm][key]["precision"] for key in value_dict[segm].keys()]
193-
recall = [value_dict[segm][key]["recall"] for key in value_dict[segm].keys()]
194-
f1 = [value_dict[segm][key]["f1-score"] for key in value_dict[segm].keys()]
195-
labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()]
196-
197-
x_pos = np.array([i * 2 for i in range(len(precision))])
198-
199245
# Convert setting labels to numerical x positions
200246
offset = 0.08 # horizontal shift for scatter separation
201247

202248
# Plot
203249
fig, ax = plt.subplots(figsize=(8, 5))
204250

205-
main_label_size = 22
251+
main_label_size = 20
206252
main_tick_size = 16
207253

208-
plt.scatter(x_pos - offset, precision, label="Precision", color=COLOR_P, marker="^", s=80)
209-
plt.scatter(x_pos, recall, label="Recall", color=COLOR_R, marker="o", s=80)
210-
plt.scatter(x_pos + offset, f1, label="F1-score manual", color=COLOR_F, marker="s", s=80)
254+
labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()]
211255

212-
# Labels and formatting
213-
plt.xticks(x_pos, labels, fontsize=16)
214-
plt.yticks(fontsize=main_tick_size)
215-
plt.ylabel("Value", fontsize=main_label_size)
216-
plt.ylim(-0.1, 1)
217-
# plt.legend(loc="lower right", fontsize=legendsize)
218-
plt.grid(axis="y", linestyle="solid", alpha=0.5)
256+
if mode == "precision":
257+
# Convert setting labels to numerical x positions
258+
offset = 0.08 # horizontal shift for scatter separation
259+
for num, key in enumerate(list(value_dict[segm].keys())):
260+
precision = [value_dict[segm][key]["precision"]]
261+
recall = [value_dict[segm][key]["recall"]]
262+
f1score = [value_dict[segm][key]["f1-score"]]
263+
marker = value_dict[segm][key]["marker"]
264+
x_pos = num + 1
265+
266+
plt.scatter([x_pos - offset], precision, label="Precision manual", color=COLOR_P, marker=marker, s=80)
267+
plt.scatter([x_pos], recall, label="Recall manual", color=COLOR_R, marker=marker, s=80)
268+
plt.scatter([x_pos + offset], f1score, label="F1-score manual", color=COLOR_F, marker=marker, s=80)
269+
270+
# Labels and formatting
271+
x_pos = np.arange(1, len(labels)+1)
272+
print(x_pos)
273+
plt.xticks(x_pos, labels, fontsize=16)
274+
plt.yticks(fontsize=main_tick_size)
275+
plt.ylabel("Value", fontsize=main_label_size)
276+
plt.ylim(-0.1, 1)
277+
# plt.legend(loc="lower right", fontsize=legendsize)
278+
plt.grid(axis="y", linestyle="solid", alpha=0.5)
279+
280+
elif mode == "runtime":
281+
# Convert setting labels to numerical x positions
282+
offset = 0.08 # horizontal shift for scatter separation
283+
for num, key in enumerate(list(value_dict[segm].keys())):
284+
runtime = [value_dict[segm][key]["runtime"]]
285+
marker = value_dict[segm][key]["marker"]
286+
x_pos = num + 1
287+
288+
plt.scatter([x_pos], runtime, label="Runtime", color=COLOR_T, marker=marker, s=80)
289+
290+
# Labels and formatting
291+
x_pos = np.arange(1, len(labels)+1)
292+
print(x_pos)
293+
plt.xticks(x_pos, labels, fontsize=16)
294+
plt.yticks(fontsize=main_tick_size)
295+
plt.ylabel("Processing time [s]", fontsize=main_label_size)
296+
plt.ylim(-0.1, 600)
297+
# plt.legend(loc="lower right", fontsize=legendsize)
298+
plt.grid(axis="y", linestyle="solid", alpha=0.5)
219299

220300
plt.tight_layout()
221301
prism_cleanup_axes(ax)
@@ -553,10 +633,14 @@ def main():
553633
# Panel C: Evaluation of the segmentation results:
554634
fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False)
555635
plot_legend_fig02c(figure_dir=args.figure_dir)
636+
plot_legend_suppfig02(figure_dir=args.figure_dir)
556637

557638
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn.{FILE_EXTENSION}"), segm="SGN")
558639
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc.{FILE_EXTENSION}"), segm="IHC")
559640

641+
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn_time.{FILE_EXTENSION}"), segm="SGN", mode="runtime")
642+
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc_time.{FILE_EXTENSION}"), segm="IHC", mode="runtime")
643+
560644
# Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC
561645
fig_02d_01(save_path=os.path.join(args.figure_dir, f"fig_02d.{FILE_EXTENSION}"),
562646
plot=args.plot, plot_average_ribbon_synapses=True)

0 commit comments

Comments
 (0)