Skip to content

Commit 373c6bd

Browse files
committed
Legends disconnected from plots
1 parent e31001e commit 373c6bd

File tree

5 files changed

+400
-171
lines changed

5 files changed

+400
-171
lines changed

scripts/figures/plot_fig2.py

Lines changed: 99 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,22 @@
44
import numpy as np
55
import pandas as pd
66
import matplotlib.pyplot as plt
7+
import matplotlib.ticker as mticker
8+
from matplotlib.lines import Line2D
79
import tifffile
810
from matplotlib import colors
911
from skimage.segmentation import find_boundaries
1012

1113
from util import literature_reference_values, SYNAPSE_DIR_ROOT
12-
from util import prism_style, prism_cleanup_axes
14+
from util import prism_style, prism_cleanup_axes, export_legend, custom_formatter_2
1315

1416
png_dpi = 300
1517
FILE_EXTENSION = "png"
1618

19+
COLOR_P = "#9C5027"
20+
COLOR_R = "#67279C"
21+
COLOR_F = "#9C276F"
22+
1723

1824
def scramble_instance_labels(arr):
1925
"""Scramble indexes of instance segmentation to avoid neighboring colors.
@@ -118,64 +124,65 @@ def fig_02b_ihc(save_dir, plot=False):
118124

119125
plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)
120126

127+
121128
def supp_fig_02(save_path, plot=False, segm="SGN"):
122129
# SGN
123130
value_dict = {
124131
"SGN": {
125-
"distance_unet" : {
126-
"label" : "CochleaNet",
132+
"distance_unet": {
133+
"label": "CochleaNet",
127134
"precision": 0.886,
128-
"recall" : 0.804,
135+
"recall": 0.804,
129136
"f1-score": 0.837
130137
},
131-
"micro_sam" : {
132-
"label" : "µSAM",
138+
"micro_sam": {
139+
"label": "µSAM",
133140
"precision": 0.140,
134-
"recall" : 0.782,
141+
"recall": 0.782,
135142
"f1-score": 0.228
136143
},
137-
"cellpose_sam" : {
138-
"label" : "Cellpose-SAM",
144+
"cellpose_sam": {
145+
"label": "Cellpose-SAM",
139146
"precision": 0.250,
140-
"recall" : 0.003,
147+
"recall": 0.003,
141148
"f1-score": 0.005
142149
},
143-
"cellpose_3" : {
144-
"label" : "Cellpose 3",
150+
"cellpose_3": {
151+
"label": "Cellpose 3",
145152
"precision": 0.117,
146-
"recall" : 0.607,
153+
"recall": 0.607,
147154
"f1-score": 0.186
148155
},
149-
"stardist" : {
150-
"label" : "Stardist",
156+
"stardist": {
157+
"label": "Stardist",
151158
"precision": 0.706,
152-
"recall" : 0.630,
159+
"recall": 0.630,
153160
"f1-score": 0.628
154161
},
155162
},
156163
"IHC": {
157-
"distance_unet" : {
158-
"label" : "CochleaNet",
164+
"distance_unet": {
165+
"label": "CochleaNet",
159166
"precision": 0.664,
160-
"recall" : 0.661,
167+
"recall": 0.661,
161168
"f1-score": 0.659
162169
},
163-
"micro_sam" : {
164-
"label" : "µSAM",
170+
"micro_sam": {
171+
"label": "µSAM",
165172
"precision": 0.053,
166-
"recall" : 0.684,
173+
"recall": 0.684,
167174
"f1-score": 0.094
168175
},
169-
"cellpose_sam" : {
170-
"label" : "Cellpose-SAM",
176+
"cellpose_sam": {
177+
"label": "Cellpose-SAM",
171178
"precision": 0.636,
172-
"recall" : 0.025,
179+
"recall": 0.025,
173180
"f1-score": 0.047
174181
},
175-
"cellpose_3" : {
176-
"label" : "Cellpose 3",
182+
"cellpose_3": {
183+
"label": "Cellpose 3",
177184
"precision": 0.375,
178-
"recall" : 0.554,
185+
"recall": 0.554,
179186
"f1-score": 0.329
180187
},
181188
}
@@ -190,9 +197,6 @@ def supp_fig_02(save_path, plot=False, segm="SGN"):
190197
x_pos = np.array([i * 2 for i in range(len(precision))])
191198

192199
# Convert setting labels to numerical x positions
193-
x = np.array([0.8, 1.2, 1.8, 2.2, 3])
194-
x_manual = np.array([0.8, 1.8])
195-
x_automatic = np.array([1.2, 2.2, 3])
196200
offset = 0.08 # horizontal shift for scatter separation
197201

198202
# Plot
@@ -201,21 +205,17 @@ def supp_fig_02(save_path, plot=False, segm="SGN"):
201205
main_label_size = 22
202206
main_tick_size = 16
203207

204-
color_p = "#3AA67E"
205-
color_r = "#438CA7"
206-
color_f = "#694BA6"
207-
208-
plt.scatter(x_pos - offset, precision, label="Precision", color=color_p, marker="^", s=80)
209-
plt.scatter(x_pos, recall, label="Recall", color=color_r, marker="o", s=80)
210-
plt.scatter(x_pos + offset, f1, label="F1-score manual", color=color_f, marker="s", s=80)
208+
plt.scatter(x_pos - offset, precision, label="Precision", color=COLOR_P, marker="^", s=80)
209+
plt.scatter(x_pos, recall, label="Recall", color=COLOR_R, marker="o", s=80)
210+
plt.scatter(x_pos + offset, f1, label="F1-score manual", color=COLOR_F, marker="s", s=80)
211211

212212
# Labels and formatting
213213
plt.xticks(x_pos, labels, fontsize=16)
214214
plt.yticks(fontsize=main_tick_size)
215215
plt.ylabel("Value", fontsize=main_label_size)
216216
plt.ylim(-0.1, 1)
217217
# plt.legend(loc="lower right", fontsize=legendsize)
218-
plt.grid(axis="y", linestyle="--", alpha=0.5)
218+
plt.grid(axis="y", linestyle="solid", alpha=0.5)
219219

220220
plt.tight_layout()
221221
prism_cleanup_axes(ax)
@@ -231,6 +231,44 @@ def supp_fig_02(save_path, plot=False, segm="SGN"):
231231
plt.close()
232232

233233

234+
def plot_legend_fig02c(figure_dir):
235+
"""Plot common legend for figure 2c.
236+
237+
Args:
238+
chreef_data: Data of ChReef cochleae.
239+
save_path: save path to save legend.
240+
grouping: Grouping for cochleae.
241+
"side_mono" for division in Injected and Non-Injected.
242+
"side_multi" for division per cochlea.
243+
"animal" for division per animal.
244+
use_alias: Use alias.
245+
"""
246+
save_path_shapes = os.path.join(figure_dir, f"fig_02c_legend_shapes.{FILE_EXTENSION}")
247+
save_path_colors = os.path.join(figure_dir, f"fig_02c_legend_colors.{FILE_EXTENSION}")
248+
249+
# Shapes
250+
color = ["black", "black"]
251+
marker = ["o", "s"]
252+
label = ["Manual", "Automatic"]
253+
254+
f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0]
255+
handles = [f(m, c) for (c, m) in zip(color, marker)]
256+
legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False)
257+
export_legend(legend, save_path_shapes)
258+
legend.remove()
259+
plt.close()
260+
261+
# Colors
262+
color = [COLOR_P, COLOR_R, COLOR_F]
263+
label = ["Precision", "Recall", "F1-score"]
264+
265+
fl = lambda c: Line2D([], [], lw=3, color=c)
266+
handles = [fl(c) for c in color]
267+
legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False)
268+
export_legend(legend, save_path_colors)
269+
legend.remove()
270+
plt.close()
271+
234272

235273
def fig_02c(save_path, plot=False, all_versions=False):
236274
"""Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
@@ -261,44 +299,33 @@ def fig_02c(save_path, plot=False, all_versions=False):
261299
recall_automatic = [i[1] for i in automatic]
262300
f1score_automatic = [i[2] for i in automatic]
263301

264-
descr_y = 0.72
265-
266302
# Convert setting labels to numerical x positions
267-
x = np.array([0.8, 1.2, 1.8, 2.2, 3])
268303
x_manual = np.array([0.8, 1.8])
269304
x_automatic = np.array([1.2, 2.2, 3])
270305
offset = 0.08 # horizontal shift for scatter separation
271306

272307
# Plot
273-
fig, ax = plt.subplots(figsize=(8, 5))
308+
fig, ax = plt.subplots(figsize=(8, 4.5))
274309

275-
main_label_size = 22
276-
sub_label_size = 16
310+
main_label_size = 20
277311
main_tick_size = 16
278-
legendsize = 18
279312

280-
color_pm = "#3AA67E"
281-
color_pa = "#17E69A"
282-
color_rm = "#438CA7"
283-
color_ra = "#17AEE6"
284-
color_fm = "#694BA6"
285-
color_fa = "#6322E6"
313+
plt.scatter(x_manual - offset, precision_manual, label="Precision manual", color=COLOR_P, marker="o", s=80)
314+
plt.scatter(x_manual, recall_manual, label="Recall manual", color=COLOR_R, marker="o", s=80)
315+
plt.scatter(x_manual + offset, f1score_manual, label="F1-score manual", color=COLOR_F, marker="o", s=80)
286316

287-
plt.scatter(x_manual - offset, precision_manual, label="Precision manual", color=color_pm, marker="o", s=80)
288-
plt.scatter(x_manual, recall_manual, label="Recall manual", color=color_rm, marker="o", s=80)
289-
plt.scatter(x_manual + offset, f1score_manual, label="F1-score manual", color=color_fm, marker="o", s=80)
290-
291-
plt.scatter(x_automatic - offset, precision_automatic, label="Precision automatic", color=color_pa, marker="s", s=80)
292-
plt.scatter(x_automatic, recall_automatic, label="Recall automatic", color=color_ra, marker="s", s=80)
293-
plt.scatter(x_automatic + offset, f1score_automatic, label="F1-score automatic", color=color_fa, marker="s", s=80)
317+
plt.scatter(x_automatic - offset, precision_automatic, label="Precision automatic", color=COLOR_P, marker="s", s=80)
318+
plt.scatter(x_automatic, recall_automatic, label="Recall automatic", color=COLOR_R, marker="s", s=80)
319+
plt.scatter(x_automatic + offset, f1score_automatic, label="F1-score automatic", color=COLOR_F, marker="s", s=80)
294320

295321
# Labels and formatting
296-
plt.xticks([1,2,3], setting, fontsize=main_label_size)
322+
plt.xticks([1, 2, 3], setting, fontsize=main_label_size)
297323
plt.yticks(fontsize=main_tick_size)
324+
ax.yaxis.set_major_formatter(mticker.FuncFormatter(custom_formatter_2))
298325
plt.ylabel("Value", fontsize=main_label_size)
299326
plt.ylim(0.76, 1)
300327
# plt.legend(loc="lower right", fontsize=legendsize)
301-
plt.grid(axis="y", linestyle="--", alpha=0.5)
328+
plt.grid(axis="y", linestyle="solid", alpha=0.5)
302329

303330
plt.tight_layout()
304331
prism_cleanup_axes(ax)
@@ -329,27 +356,21 @@ def _load_ribbon_synapse_counts():
329356
def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_synapses=False):
330357
"""Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values.
331358
"""
332-
main_tick_size = 20
333-
main_label_size = 26
359+
prism_style()
360+
main_tick_size = 16
361+
main_label_size = 20
334362

335363
rows = 1
336364
columns = 3 if plot_average_ribbon_synapses else 2
337365

338366
sgn_values = [11153, 11398, 10333, 11820]
339-
ihc_v4b_values = [836, 808, 796, 901]
340367
ihc_v4c_values = [712, 710, 721, 675]
341-
ihc_v4c_filtered_values = [562, 647, 626, 628]
342368

343-
if all_versions:
344-
ihc_list = [ihc_v4b_values, ihc_v4c_values, ihc_v4c_filtered_values]
345-
suffixes = ["_v4b", "_v4c", "_v4c_filtered"]
346-
assert not plot_average_ribbon_synapses
347-
else:
348-
ihc_list = [ihc_v4c_values]
349-
suffixes = ["_v4c"]
369+
ihc_list = [ihc_v4c_values]
370+
suffixes = ["_v4c"]
350371

351372
for (ihc_values, suffix) in zip(ihc_list, suffixes):
352-
fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4))
373+
fig, axes = plt.subplots(rows, columns, figsize=(10, 4.5))
353374
ax = axes.flatten()
354375

355376
save_path_new = save_path.split(".")[0] + suffix + "." + save_path.split(".")[1]
@@ -376,7 +397,7 @@ def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_sy
376397
lower_y, upper_y = literature_reference_values("SGN")
377398
ax[0].hlines([lower_y, upper_y], xmin, xmax)
378399
ax[0].text(1., lower_y + (upper_y - lower_y) * 0.2, "literature",
379-
color="C0", fontsize=main_tick_size, ha="center")
400+
color="C0", fontsize=main_label_size, ha="center")
380401
ax[0].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True)
381402

382403
ylim0 = 600
@@ -407,7 +428,7 @@ def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_sy
407428
y_ticks = [0, 10, 20, 30, 40, 50]
408429

409430
ax[2].boxplot(ribbon_synapse_counts)
410-
ax[2].set_xticklabels(["Ribbon Syn. per IHC"], fontsize=main_label_size)
431+
ax[2].set_xticklabels(["Synapses per IHC"], fontsize=main_label_size)
411432
ax[2].set_yticks(y_ticks)
412433
ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
413434
ax[2].set_ylim(ylim0, ylim1)
@@ -421,6 +442,7 @@ def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_sy
421442
# ax[2].text(1.1, (lower_y + upper_y) // 2, "literature", color="C0", fontsize=main_tick_size, ha="left")
422443
ax[2].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True)
423444

445+
prism_cleanup_axes(axes)
424446
plt.tight_layout()
425447

426448
if ".png" in save_path:
@@ -501,7 +523,7 @@ def fig_02d_02(save_path, filter_zeros=True, plot=False):
501523

502524
plt.title("Average Synapses per IHC for a Dataset of 4 Cochleae")
503525

504-
plt.grid(axis="y", linestyle="--", alpha=0.5)
526+
plt.grid(axis="y", linestyle="solid", alpha=0.5)
505527
plt.legend(fontsize=legendsize)
506528
plt.tight_layout()
507529

@@ -530,6 +552,7 @@ def main():
530552

531553
# Panel C: Evaluation of the segmentation results:
532554
fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False)
555+
plot_legend_fig02c(figure_dir=args.figure_dir)
533556

534557
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn.{FILE_EXTENSION}"), segm="SGN")
535558
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc.{FILE_EXTENSION}"), segm="IHC")

0 commit comments

Comments
 (0)