44import numpy as np
55import pandas as pd
66import matplotlib .pyplot as plt
7+ import matplotlib .ticker as mticker
8+ from matplotlib .lines import Line2D
79import tifffile
810from matplotlib import colors
911from skimage .segmentation import find_boundaries
1012
1113from util import literature_reference_values , SYNAPSE_DIR_ROOT
12- from util import prism_style , prism_cleanup_axes
14+ from util import prism_style , prism_cleanup_axes , export_legend , custom_formatter_2
1315
1416png_dpi = 300
1517FILE_EXTENSION = "png"
1618
19+ COLOR_P = "#9C5027"
20+ COLOR_R = "#67279C"
21+ COLOR_F = "#9C276F"
22+
1723
1824def scramble_instance_labels (arr ):
1925 """Scramble indexes of instance segmentation to avoid neighboring colors.
@@ -118,64 +124,65 @@ def fig_02b_ihc(save_dir, plot=False):
118124
119125 plot_seg_crop (img_path , seg_path , save_path , xlim1 , xlim2 , ylim1 , ylim2 , boundary_rgba , plot = plot )
120126
127+
121128def supp_fig_02 (save_path , plot = False , segm = "SGN" ):
122129 # SGN
123130 value_dict = {
124131 "SGN" : {
125- "distance_unet" : {
126- "label" : "CochleaNet" ,
132+ "distance_unet" : {
133+ "label" : "CochleaNet" ,
127134 "precision" : 0.886 ,
128- "recall" : 0.804 ,
135+ "recall" : 0.804 ,
129136 "f1-score" : 0.837
130137 },
131- "micro_sam" : {
132- "label" : "µSAM" ,
138+ "micro_sam" : {
139+ "label" : "µSAM" ,
133140 "precision" : 0.140 ,
134- "recall" : 0.782 ,
141+ "recall" : 0.782 ,
135142 "f1-score" : 0.228
136143 },
137- "cellpose_sam" : {
138- "label" : "Cellpose-SAM" ,
144+ "cellpose_sam" : {
145+ "label" : "Cellpose-SAM" ,
139146 "precision" : 0.250 ,
140- "recall" : 0.003 ,
147+ "recall" : 0.003 ,
141148 "f1-score" : 0.005
142149 },
143- "cellpose_3" : {
144- "label" : "Cellpose 3" ,
150+ "cellpose_3" : {
151+ "label" : "Cellpose 3" ,
145152 "precision" : 0.117 ,
146- "recall" : 0.607 ,
153+ "recall" : 0.607 ,
147154 "f1-score" : 0.186
148155 },
149- "stardist" : {
150- "label" : "Stardist" ,
156+ "stardist" : {
157+ "label" : "Stardist" ,
151158 "precision" : 0.706 ,
152- "recall" : 0.630 ,
159+ "recall" : 0.630 ,
153160 "f1-score" : 0.628
154161 },
155162 },
156163 "IHC" : {
157- "distance_unet" : {
158- "label" : "CochleaNet" ,
164+ "distance_unet" : {
165+ "label" : "CochleaNet" ,
159166 "precision" : 0.664 ,
160- "recall" : 0.661 ,
167+ "recall" : 0.661 ,
161168 "f1-score" : 0.659
162169 },
163- "micro_sam" : {
164- "label" : "µSAM" ,
170+ "micro_sam" : {
171+ "label" : "µSAM" ,
165172 "precision" : 0.053 ,
166- "recall" : 0.684 ,
173+ "recall" : 0.684 ,
167174 "f1-score" : 0.094
168175 },
169- "cellpose_sam" : {
170- "label" : "Cellpose-SAM" ,
176+ "cellpose_sam" : {
177+ "label" : "Cellpose-SAM" ,
171178 "precision" : 0.636 ,
172- "recall" : 0.025 ,
179+ "recall" : 0.025 ,
173180 "f1-score" : 0.047
174181 },
175- "cellpose_3" : {
176- "label" : "Cellpose 3" ,
182+ "cellpose_3" : {
183+ "label" : "Cellpose 3" ,
177184 "precision" : 0.375 ,
178- "recall" : 0.554 ,
185+ "recall" : 0.554 ,
179186 "f1-score" : 0.329
180187 },
181188 }
@@ -190,9 +197,6 @@ def supp_fig_02(save_path, plot=False, segm="SGN"):
190197 x_pos = np .array ([i * 2 for i in range (len (precision ))])
191198
192199 # Convert setting labels to numerical x positions
193- x = np .array ([0.8 , 1.2 , 1.8 , 2.2 , 3 ])
194- x_manual = np .array ([0.8 , 1.8 ])
195- x_automatic = np .array ([1.2 , 2.2 , 3 ])
196200 offset = 0.08 # horizontal shift for scatter separation
197201
198202 # Plot
@@ -201,21 +205,17 @@ def supp_fig_02(save_path, plot=False, segm="SGN"):
201205 main_label_size = 22
202206 main_tick_size = 16
203207
204- color_p = "#3AA67E"
205- color_r = "#438CA7"
206- color_f = "#694BA6"
207-
208- plt .scatter (x_pos - offset , precision , label = "Precision" , color = color_p , marker = "^" , s = 80 )
209- plt .scatter (x_pos , recall , label = "Recall" , color = color_r , marker = "o" , s = 80 )
210- plt .scatter (x_pos + offset , f1 , label = "F1-score manual" , color = color_f , marker = "s" , s = 80 )
208+ plt .scatter (x_pos - offset , precision , label = "Precision" , color = COLOR_P , marker = "^" , s = 80 )
209+ plt .scatter (x_pos , recall , label = "Recall" , color = COLOR_R , marker = "o" , s = 80 )
210+ plt .scatter (x_pos + offset , f1 , label = "F1-score manual" , color = COLOR_F , marker = "s" , s = 80 )
211211
212212 # Labels and formatting
213213 plt .xticks (x_pos , labels , fontsize = 16 )
214214 plt .yticks (fontsize = main_tick_size )
215215 plt .ylabel ("Value" , fontsize = main_label_size )
216216 plt .ylim (- 0.1 , 1 )
217217 # plt.legend(loc="lower right", fontsize=legendsize)
218- plt .grid (axis = "y" , linestyle = "-- " , alpha = 0.5 )
218+ plt .grid (axis = "y" , linestyle = "solid " , alpha = 0.5 )
219219
220220 plt .tight_layout ()
221221 prism_cleanup_axes (ax )
@@ -231,6 +231,44 @@ def supp_fig_02(save_path, plot=False, segm="SGN"):
231231 plt .close ()
232232
233233
234+ def plot_legend_fig02c (figure_dir ):
235+ """Plot common legend for figure 2c.
236+
237+ Args:
238+ chreef_data: Data of ChReef cochleae.
239+ save_path: save path to save legend.
240+ grouping: Grouping for cochleae.
241+ "side_mono" for division in Injected and Non-Injected.
242+ "side_multi" for division per cochlea.
243+ "animal" for division per animal.
244+ use_alias: Use alias.
245+ """
246+ save_path_shapes = os .path .join (figure_dir , f"fig_02c_legend_shapes.{ FILE_EXTENSION } " )
247+ save_path_colors = os .path .join (figure_dir , f"fig_02c_legend_colors.{ FILE_EXTENSION } " )
248+
249+ # Shapes
250+ color = ["black" , "black" ]
251+ marker = ["o" , "s" ]
252+ label = ["Manual" , "Automatic" ]
253+
254+ f = lambda m , c : plt .plot ([], [], marker = m , color = c , ls = "none" )[0 ]
255+ handles = [f (m , c ) for (c , m ) in zip (color , marker )]
256+ legend = plt .legend (handles , label , loc = 3 , ncol = len (label ), framealpha = 1 , frameon = False )
257+ export_legend (legend , save_path_shapes )
258+ legend .remove ()
259+ plt .close ()
260+
261+ # Colors
262+ color = [COLOR_P , COLOR_R , COLOR_F ]
263+ label = ["Precision" , "Recall" , "F1-score" ]
264+
265+ fl = lambda c : Line2D ([], [], lw = 3 , color = c )
266+ handles = [fl (c ) for c in color ]
267+ legend = plt .legend (handles , label , loc = 3 , ncol = len (label ), framealpha = 1 , frameon = False )
268+ export_legend (legend , save_path_colors )
269+ legend .remove ()
270+ plt .close ()
271+
234272
235273def fig_02c (save_path , plot = False , all_versions = False ):
236274 """Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
@@ -261,44 +299,33 @@ def fig_02c(save_path, plot=False, all_versions=False):
261299 recall_automatic = [i [1 ] for i in automatic ]
262300 f1score_automatic = [i [2 ] for i in automatic ]
263301
264- descr_y = 0.72
265-
266302 # Convert setting labels to numerical x positions
267- x = np .array ([0.8 , 1.2 , 1.8 , 2.2 , 3 ])
268303 x_manual = np .array ([0.8 , 1.8 ])
269304 x_automatic = np .array ([1.2 , 2.2 , 3 ])
270305 offset = 0.08 # horizontal shift for scatter separation
271306
272307 # Plot
273- fig , ax = plt .subplots (figsize = (8 , 5 ))
308+ fig , ax = plt .subplots (figsize = (8 , 4. 5 ))
274309
275- main_label_size = 22
276- sub_label_size = 16
310+ main_label_size = 20
277311 main_tick_size = 16
278- legendsize = 18
279312
280- color_pm = "#3AA67E"
281- color_pa = "#17E69A"
282- color_rm = "#438CA7"
283- color_ra = "#17AEE6"
284- color_fm = "#694BA6"
285- color_fa = "#6322E6"
313+ plt .scatter (x_manual - offset , precision_manual , label = "Precision manual" , color = COLOR_P , marker = "o" , s = 80 )
314+ plt .scatter (x_manual , recall_manual , label = "Recall manual" , color = COLOR_R , marker = "o" , s = 80 )
315+ plt .scatter (x_manual + offset , f1score_manual , label = "F1-score manual" , color = COLOR_F , marker = "o" , s = 80 )
286316
287- plt .scatter (x_manual - offset , precision_manual , label = "Precision manual" , color = color_pm , marker = "o" , s = 80 )
288- plt .scatter (x_manual , recall_manual , label = "Recall manual" , color = color_rm , marker = "o" , s = 80 )
289- plt .scatter (x_manual + offset , f1score_manual , label = "F1-score manual" , color = color_fm , marker = "o" , s = 80 )
290-
291- plt .scatter (x_automatic - offset , precision_automatic , label = "Precision automatic" , color = color_pa , marker = "s" , s = 80 )
292- plt .scatter (x_automatic , recall_automatic , label = "Recall automatic" , color = color_ra , marker = "s" , s = 80 )
293- plt .scatter (x_automatic + offset , f1score_automatic , label = "F1-score automatic" , color = color_fa , marker = "s" , s = 80 )
317+ plt .scatter (x_automatic - offset , precision_automatic , label = "Precision automatic" , color = COLOR_P , marker = "s" , s = 80 )
318+ plt .scatter (x_automatic , recall_automatic , label = "Recall automatic" , color = COLOR_R , marker = "s" , s = 80 )
319+ plt .scatter (x_automatic + offset , f1score_automatic , label = "F1-score automatic" , color = COLOR_F , marker = "s" , s = 80 )
294320
295321 # Labels and formatting
296- plt .xticks ([1 ,2 , 3 ], setting , fontsize = main_label_size )
322+ plt .xticks ([1 , 2 , 3 ], setting , fontsize = main_label_size )
297323 plt .yticks (fontsize = main_tick_size )
324+ ax .yaxis .set_major_formatter (mticker .FuncFormatter (custom_formatter_2 ))
298325 plt .ylabel ("Value" , fontsize = main_label_size )
299326 plt .ylim (0.76 , 1 )
300327 # plt.legend(loc="lower right", fontsize=legendsize)
301- plt .grid (axis = "y" , linestyle = "-- " , alpha = 0.5 )
328+ plt .grid (axis = "y" , linestyle = "solid " , alpha = 0.5 )
302329
303330 plt .tight_layout ()
304331 prism_cleanup_axes (ax )
@@ -329,27 +356,21 @@ def _load_ribbon_synapse_counts():
329356def fig_02d_01 (save_path , plot = False , all_versions = False , plot_average_ribbon_synapses = False ):
330357 """Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values.
331358 """
332- main_tick_size = 20
333- main_label_size = 26
359+ prism_style ()
360+ main_tick_size = 16
361+ main_label_size = 20
334362
335363 rows = 1
336364 columns = 3 if plot_average_ribbon_synapses else 2
337365
338366 sgn_values = [11153 , 11398 , 10333 , 11820 ]
339- ihc_v4b_values = [836 , 808 , 796 , 901 ]
340367 ihc_v4c_values = [712 , 710 , 721 , 675 ]
341- ihc_v4c_filtered_values = [562 , 647 , 626 , 628 ]
342368
343- if all_versions :
344- ihc_list = [ihc_v4b_values , ihc_v4c_values , ihc_v4c_filtered_values ]
345- suffixes = ["_v4b" , "_v4c" , "_v4c_filtered" ]
346- assert not plot_average_ribbon_synapses
347- else :
348- ihc_list = [ihc_v4c_values ]
349- suffixes = ["_v4c" ]
369+ ihc_list = [ihc_v4c_values ]
370+ suffixes = ["_v4c" ]
350371
351372 for (ihc_values , suffix ) in zip (ihc_list , suffixes ):
352- fig , axes = plt .subplots (rows , columns , figsize = (columns * 4 , rows * 4 ))
373+ fig , axes = plt .subplots (rows , columns , figsize = (10 , 4.5 ))
353374 ax = axes .flatten ()
354375
355376 save_path_new = save_path .split ("." )[0 ] + suffix + "." + save_path .split ("." )[1 ]
@@ -376,7 +397,7 @@ def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_sy
376397 lower_y , upper_y = literature_reference_values ("SGN" )
377398 ax [0 ].hlines ([lower_y , upper_y ], xmin , xmax )
378399 ax [0 ].text (1. , lower_y + (upper_y - lower_y ) * 0.2 , "literature" ,
379- color = "C0" , fontsize = main_tick_size , ha = "center" )
400+ color = "C0" , fontsize = main_label_size , ha = "center" )
380401 ax [0 ].fill_between ([xmin , xmax ], lower_y , upper_y , color = "C0" , alpha = 0.05 , interpolate = True )
381402
382403 ylim0 = 600
@@ -407,7 +428,7 @@ def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_sy
407428 y_ticks = [0 , 10 , 20 , 30 , 40 , 50 ]
408429
409430 ax [2 ].boxplot (ribbon_synapse_counts )
410- ax [2 ].set_xticklabels (["Ribbon Syn. per IHC" ], fontsize = main_label_size )
431+ ax [2 ].set_xticklabels (["Synapses per IHC" ], fontsize = main_label_size )
411432 ax [2 ].set_yticks (y_ticks )
412433 ax [2 ].set_yticklabels (y_ticks , rotation = 0 , fontsize = main_tick_size )
413434 ax [2 ].set_ylim (ylim0 , ylim1 )
@@ -421,6 +442,7 @@ def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_sy
421442 # ax[2].text(1.1, (lower_y + upper_y) // 2, "literature", color="C0", fontsize=main_tick_size, ha="left")
422443 ax [2 ].fill_between ([xmin , xmax ], lower_y , upper_y , color = "C0" , alpha = 0.05 , interpolate = True )
423444
445+ prism_cleanup_axes (axes )
424446 plt .tight_layout ()
425447
426448 if ".png" in save_path :
@@ -501,7 +523,7 @@ def fig_02d_02(save_path, filter_zeros=True, plot=False):
501523
502524 plt .title ("Average Synapses per IHC for a Dataset of 4 Cochleae" )
503525
504- plt .grid (axis = "y" , linestyle = "-- " , alpha = 0.5 )
526+ plt .grid (axis = "y" , linestyle = "solid " , alpha = 0.5 )
505527 plt .legend (fontsize = legendsize )
506528 plt .tight_layout ()
507529
@@ -530,6 +552,7 @@ def main():
530552
531553 # Panel C: Evaluation of the segmentation results:
532554 fig_02c (save_path = os .path .join (args .figure_dir , f"fig_02c.{ FILE_EXTENSION } " ), plot = args .plot , all_versions = False )
555+ plot_legend_fig02c (figure_dir = args .figure_dir )
533556
534557 supp_fig_02 (save_path = os .path .join (args .figure_dir , f"figsupp_02_sgn.{ FILE_EXTENSION } " ), segm = "SGN" )
535558 supp_fig_02 (save_path = os .path .join (args .figure_dir , f"figsupp_02_ihc.{ FILE_EXTENSION } " ), segm = "IHC" )
0 commit comments