| 
19 | 19 | COLOR_P = "#9C5027"  | 
20 | 20 | COLOR_R = "#67279C"  | 
21 | 21 | COLOR_F = "#9C276F"  | 
 | 22 | +COLOR_T = "#279C52"  | 
22 | 23 | 
 
  | 
23 | 24 | 
 
  | 
24 | 25 | def scramble_instance_labels(arr):  | 
@@ -125,97 +126,176 @@ def fig_02b_ihc(save_dir, plot=False):  | 
125 | 126 |         plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)  | 
126 | 127 | 
 
  | 
127 | 128 | 
 
  | 
128 |  | -def supp_fig_02(save_path, plot=False, segm="SGN"):  | 
 | 129 | +def plot_legend_suppfig02(figure_dir):  | 
 | 130 | +    """Plot common legend for figure 2c.  | 
 | 131 | +
  | 
 | 132 | +    Args:  | 
 | 133 | +        chreef_data: Data of ChReef cochleae.  | 
 | 134 | +        save_path: save path to save legend.  | 
 | 135 | +        grouping: Grouping for cochleae.  | 
 | 136 | +            "side_mono" for division in Injected and Non-Injected.  | 
 | 137 | +            "side_multi" for division per cochlea.  | 
 | 138 | +            "animal" for division per animal.  | 
 | 139 | +        use_alias: Use alias.  | 
 | 140 | +    """  | 
 | 141 | +    save_path_colors = os.path.join(figure_dir, f"suppfig02_legend_colors.{FILE_EXTENSION}")  | 
 | 142 | +    # Colors  | 
 | 143 | +    color = [COLOR_P, COLOR_R, COLOR_F, COLOR_T]  | 
 | 144 | +    label = ["Precision", "Recall", "F1-score", "Processing time"]  | 
 | 145 | + | 
 | 146 | +    fl = lambda c: Line2D([], [], lw=3, color=c)  | 
 | 147 | +    handles = [fl(c) for c in color]  | 
 | 148 | +    legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False)  | 
 | 149 | +    export_legend(legend, save_path_colors)  | 
 | 150 | +    legend.remove()  | 
 | 151 | +    plt.close()  | 
 | 152 | + | 
 | 153 | + | 
 | 154 | +def supp_fig_02(save_path, plot=False, segm="SGN", mode="precision"):  | 
129 | 155 |     # SGN  | 
130 | 156 |     value_dict = {  | 
131 | 157 |         "SGN": {  | 
132 |  | -            "distance_unet": {  | 
133 |  | -                "label": "CochleaNet",  | 
134 |  | -                "precision": 0.886,  | 
135 |  | -                "recall": 0.804,  | 
136 |  | -                "f1-score": 0.837  | 
 | 158 | +            "stardist": {  | 
 | 159 | +                "label": "Stardist",  | 
 | 160 | +                "precision": 0.706,  | 
 | 161 | +                "recall": 0.630,  | 
 | 162 | +                "f1-score": 0.628,  | 
 | 163 | +                "marker": "o",  | 
 | 164 | +                "runtime": 536.5,  | 
 | 165 | +                "runtime_std": 148.4  | 
 | 166 | + | 
137 | 167 |             },  | 
138 | 168 |             "micro_sam": {  | 
139 | 169 |                 "label": "µSAM",  | 
140 | 170 |                 "precision": 0.140,  | 
141 | 171 |                 "recall": 0.782,  | 
142 |  | -                "f1-score": 0.228  | 
143 |  | -            },  | 
144 |  | -            "cellpose_sam": {  | 
145 |  | -                "label": "Cellpose-SAM",  | 
146 |  | -                "precision": 0.250,  | 
147 |  | -                "recall": 0.003,  | 
148 |  | -                "f1-score": 0.005  | 
 | 172 | +                "f1-score": 0.228,  | 
 | 173 | +                "marker": "D",  | 
 | 174 | +                "runtime": 407.5,  | 
 | 175 | +                "runtime_std": 107.5  | 
149 | 176 |             },  | 
150 | 177 |             "cellpose_3": {  | 
151 | 178 |                 "label": "Cellpose 3",  | 
152 | 179 |                 "precision": 0.117,  | 
153 | 180 |                 "recall": 0.607,  | 
154 |  | -                "f1-score": 0.186  | 
 | 181 | +                "f1-score": 0.186,  | 
 | 182 | +                "marker": "v",  | 
 | 183 | +                "runtime": None,  | 
 | 184 | +                "runtime_std": None  | 
155 | 185 |             },  | 
156 |  | -            "stardist": {  | 
157 |  | -                "label": "Stardist",  | 
158 |  | -                "precision": 0.706,  | 
159 |  | -                "recall": 0.630,  | 
160 |  | -                "f1-score": 0.628  | 
 | 186 | +            "cellpose_sam": {  | 
 | 187 | +                "label": "Cellpose-SAM",  | 
 | 188 | +                "precision": 0.250,  | 
 | 189 | +                "recall": 0.003,  | 
 | 190 | +                "f1-score": 0.005,  | 
 | 191 | +                "marker": "^",  | 
 | 192 | +                "runtime": 167.9,  | 
 | 193 | +                "runtime_std": 40.2  | 
161 | 194 |             },  | 
162 |  | -        },  | 
163 |  | -        "IHC": {  | 
164 | 195 |             "distance_unet": {  | 
165 | 196 |                 "label": "CochleaNet",  | 
166 |  | -                "precision": 0.664,  | 
167 |  | -                "recall": 	0.661,  | 
168 |  | -                "f1-score": 0.659  | 
 | 197 | +                "precision": 0.886,  | 
 | 198 | +                "recall": 0.804,  | 
 | 199 | +                "f1-score": 0.837,  | 
 | 200 | +                "marker": "s",  | 
 | 201 | +                "runtime": 168.8,  | 
 | 202 | +                "runtime_std": 21.8  | 
169 | 203 |             },  | 
 | 204 | +        },  | 
 | 205 | +        "IHC": {  | 
170 | 206 |             "micro_sam": {  | 
171 | 207 |                 "label": "µSAM",  | 
172 | 208 |                 "precision": 0.053,  | 
173 | 209 |                 "recall": 0.684,  | 
174 |  | -                "f1-score": 0.094  | 
 | 210 | +                "f1-score": 0.094,  | 
 | 211 | +                "marker": "D",  | 
 | 212 | +                "runtime": 445.6,  | 
 | 213 | +                "runtime_std": 106.6  | 
 | 214 | +            },  | 
 | 215 | +            "cellpose_3": {  | 
 | 216 | +                "label": "Cellpose 3",  | 
 | 217 | +                "precision": 0.375,  | 
 | 218 | +                "recall": 0.554,  | 
 | 219 | +                "f1-score": 0.329,  | 
 | 220 | +                "marker": "v",  | 
 | 221 | +                "runtime": 30.1,  | 
 | 222 | +                "runtime_std": 162.3  | 
175 | 223 |             },  | 
176 | 224 |             "cellpose_sam": {  | 
177 | 225 |                 "label": "Cellpose-SAM",  | 
178 | 226 |                 "precision": 0.636,  | 
179 | 227 |                 "recall": 0.025,  | 
180 |  | -                "f1-score": 0.047  | 
 | 228 | +                "f1-score": 0.047,  | 
 | 229 | +                "marker": "^",  | 
 | 230 | +                "runtime": None,  | 
 | 231 | +                "runtime_std": None  | 
181 | 232 |             },  | 
182 |  | -            "cellpose_3": {  | 
183 |  | -                "label": "Cellpose 3",  | 
184 |  | -                "precision": 0.375,  | 
185 |  | -                "recall": 0.554,  | 
186 |  | -                "f1-score": 0.329  | 
 | 233 | +            "distance_unet": {  | 
 | 234 | +                "label": "CochleaNet",  | 
 | 235 | +                "precision": 0.664,  | 
 | 236 | +                "recall": 	0.661,  | 
 | 237 | +                "f1-score": 0.659,  | 
 | 238 | +                "marker": "s",  | 
 | 239 | +                "runtime": 65.7,  | 
 | 240 | +                "runtime_std": 72.6  | 
187 | 241 |             },  | 
188 | 242 |         }  | 
189 | 243 |     }  | 
190 | 244 | 
 
  | 
191 |  | -    # SGN  | 
192 |  | -    precision = [value_dict[segm][key]["precision"] for key in value_dict[segm].keys()]  | 
193 |  | -    recall = [value_dict[segm][key]["recall"] for key in value_dict[segm].keys()]  | 
194 |  | -    f1 = [value_dict[segm][key]["f1-score"] for key in value_dict[segm].keys()]  | 
195 |  | -    labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()]  | 
196 |  | - | 
197 |  | -    x_pos = np.array([i * 2 for i in range(len(precision))])  | 
198 |  | - | 
199 | 245 |     # Convert setting labels to numerical x positions  | 
200 | 246 |     offset = 0.08  # horizontal shift for scatter separation  | 
201 | 247 | 
 
  | 
202 | 248 |     # Plot  | 
203 | 249 |     fig, ax = plt.subplots(figsize=(8, 5))  | 
204 | 250 | 
 
  | 
205 |  | -    main_label_size = 22  | 
 | 251 | +    main_label_size = 20  | 
206 | 252 |     main_tick_size = 16  | 
207 | 253 | 
 
  | 
208 |  | -    plt.scatter(x_pos - offset, precision, label="Precision", color=COLOR_P, marker="^", s=80)  | 
209 |  | -    plt.scatter(x_pos,          recall, label="Recall", color=COLOR_R, marker="o", s=80)  | 
210 |  | -    plt.scatter(x_pos + offset, f1, label="F1-score manual", color=COLOR_F, marker="s", s=80)  | 
 | 254 | +    labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()]  | 
211 | 255 | 
 
  | 
212 |  | -    # Labels and formatting  | 
213 |  | -    plt.xticks(x_pos, labels, fontsize=16)  | 
214 |  | -    plt.yticks(fontsize=main_tick_size)  | 
215 |  | -    plt.ylabel("Value", fontsize=main_label_size)  | 
216 |  | -    plt.ylim(-0.1, 1)  | 
217 |  | -    # plt.legend(loc="lower right", fontsize=legendsize)  | 
218 |  | -    plt.grid(axis="y", linestyle="solid", alpha=0.5)  | 
 | 256 | +    if mode == "precision":  | 
 | 257 | +        # Convert setting labels to numerical x positions  | 
 | 258 | +        offset = 0.08  # horizontal shift for scatter separation  | 
 | 259 | +        for num, key in enumerate(list(value_dict[segm].keys())):  | 
 | 260 | +            precision = [value_dict[segm][key]["precision"]]  | 
 | 261 | +            recall = [value_dict[segm][key]["recall"]]  | 
 | 262 | +            f1score = [value_dict[segm][key]["f1-score"]]  | 
 | 263 | +            marker = value_dict[segm][key]["marker"]  | 
 | 264 | +            x_pos = num + 1  | 
 | 265 | + | 
 | 266 | +            plt.scatter([x_pos - offset], precision, label="Precision manual", color=COLOR_P, marker=marker, s=80)  | 
 | 267 | +            plt.scatter([x_pos],         recall, label="Recall manual", color=COLOR_R, marker=marker, s=80)  | 
 | 268 | +            plt.scatter([x_pos + offset], f1score, label="F1-score manual", color=COLOR_F, marker=marker, s=80)  | 
 | 269 | + | 
 | 270 | +        # Labels and formatting  | 
 | 271 | +        x_pos = np.arange(1, len(labels)+1)  | 
 | 272 | +        print(x_pos)  | 
 | 273 | +        plt.xticks(x_pos, labels, fontsize=16)  | 
 | 274 | +        plt.yticks(fontsize=main_tick_size)  | 
 | 275 | +        plt.ylabel("Value", fontsize=main_label_size)  | 
 | 276 | +        plt.ylim(-0.1, 1)  | 
 | 277 | +        # plt.legend(loc="lower right", fontsize=legendsize)  | 
 | 278 | +        plt.grid(axis="y", linestyle="solid", alpha=0.5)  | 
 | 279 | + | 
 | 280 | +    elif mode == "runtime":  | 
 | 281 | +        # Convert setting labels to numerical x positions  | 
 | 282 | +        offset = 0.08  # horizontal shift for scatter separation  | 
 | 283 | +        for num, key in enumerate(list(value_dict[segm].keys())):  | 
 | 284 | +            runtime = [value_dict[segm][key]["runtime"]]  | 
 | 285 | +            marker = value_dict[segm][key]["marker"]  | 
 | 286 | +            x_pos = num + 1  | 
 | 287 | + | 
 | 288 | +            plt.scatter([x_pos], runtime, label="Runtime", color=COLOR_T, marker=marker, s=80)  | 
 | 289 | + | 
 | 290 | +        # Labels and formatting  | 
 | 291 | +        x_pos = np.arange(1, len(labels)+1)  | 
 | 292 | +        print(x_pos)  | 
 | 293 | +        plt.xticks(x_pos, labels, fontsize=16)  | 
 | 294 | +        plt.yticks(fontsize=main_tick_size)  | 
 | 295 | +        plt.ylabel("Processing time [s]", fontsize=main_label_size)  | 
 | 296 | +        plt.ylim(-0.1, 600)  | 
 | 297 | +        # plt.legend(loc="lower right", fontsize=legendsize)  | 
 | 298 | +        plt.grid(axis="y", linestyle="solid", alpha=0.5)  | 
219 | 299 | 
 
  | 
220 | 300 |     plt.tight_layout()  | 
221 | 301 |     prism_cleanup_axes(ax)  | 
@@ -553,10 +633,14 @@ def main():  | 
553 | 633 |     # Panel C: Evaluation of the segmentation results:  | 
554 | 634 |     fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False)  | 
555 | 635 |     plot_legend_fig02c(figure_dir=args.figure_dir)  | 
 | 636 | +    plot_legend_suppfig02(figure_dir=args.figure_dir)  | 
556 | 637 | 
 
  | 
557 | 638 |     supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn.{FILE_EXTENSION}"), segm="SGN")  | 
558 | 639 |     supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc.{FILE_EXTENSION}"), segm="IHC")  | 
559 | 640 | 
 
  | 
 | 641 | +    supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn_time.{FILE_EXTENSION}"), segm="SGN", mode="runtime")  | 
 | 642 | +    supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc_time.{FILE_EXTENSION}"), segm="IHC", mode="runtime")  | 
 | 643 | + | 
560 | 644 |     # Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC  | 
561 | 645 |     fig_02d_01(save_path=os.path.join(args.figure_dir, f"fig_02d.{FILE_EXTENSION}"),  | 
562 | 646 |                plot=args.plot, plot_average_ribbon_synapses=True)  | 
 | 
0 commit comments