Skip to content

Commit 75022cd

Browse files
committed
improv rotation/train csv + docs
added rotation toggle in metrics + added training csv + docs
1 parent 4b97a21 commit 75022cd

File tree

5 files changed

+76
-39
lines changed

5 files changed

+76
-39
lines changed

docs/res/guides/metrics_module_guide.rst

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,20 @@ The Dice coefficient is defined as :
1111

1212
It is a measure of similarity between two sets- :math:`0` indicating no similarity and :math:`1` complete similarity.
1313

14-
You will need to provide two folders : one for ground truth labels and one for predicition labels.
14+
You will need to provide the following parameters:
15+
16+
* Two folders : one for ground truth labels and one for prediction labels.
17+
18+
* The threshold below which the score is considered insufficient.
19+
Any pair below that score will be shown on the viewer; and be displayed in red in the plot.
20+
21+
* Whether to automatically determine the best orientation for the computation by rotating and flipping;
22+
use this if your images do not have the same orientation.
23+
24+
.. note::
25+
Due to changes in orientation of images after running inference, the utility can rotate and flip images randomly to find the best Dice coefficient
26+
to compensate. If you have small images with a very large number of labels, this can lead to an inexact metric being computed.
27+
Images with a low score might be in the wrong orientation as well when displayed for comparison.
1528

1629
.. important::
1730
This utility assumes that **predictions are padded to a power of two already.** Ground truth labels can be smaller,
@@ -22,10 +35,6 @@ Once you are ready, press the "Compute Dice" button. This will plot the Dice sco
2235
Pairs with a low score will be displayed on the viewer for checking, ground truth in **blue**, low score prediction in **red**.
2336

2437

25-
.. note::
26-
Due to changes in orientation of images after running inference, the utility will rotate and flip images to find the best Dice coefficient
27-
to compensate. If you have small images with a very large number of labels, this can lead to an inexact metric being computed.
28-
Images with a low score might be in the wrong orientation as well when displayed for comparison.
2938

3039

3140
Source code

docs/res/guides/review_module_guide.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Reviewer module guide
44
=================================
55

66
This module allows you to review your labels, from predictions or manual labeling,
7-
and correct them if needed. It then saves the status of each file in a csv, for easier monitoring.
7+
and correct them if needed. It then saves the status of each file in a csv, as well as the time taken per slice review, for easier monitoring.
88

99

1010

@@ -17,13 +17,13 @@ Launching the review process
1717
Folders can be stacks of either **.png** or **.tif** files, ideally numbered with the index of the slice at the end.
1818

1919
.. note::
20-
Only single .tif files or folder of several .png or .tif are supported.
20+
Only single .tif files or folder of several .png or .tif are currently supported.
2121

22-
* Model name :
22+
* Csv file name :
2323
You can then provide a model name, which will be used to name the csv file recording the status of each slice.
2424

2525
If a corresponding csv file exists already, it will be used. If not, a new one will be created.
26-
If you choose to create a new dataset, a new csv will be created no matter what,
26+
If you choose to create a new dataset, a new csv will always be created,
2727
with a trailing number if several copies of it already exists.
2828

2929
* Start :

docs/res/guides/training_module_guide.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ The training module is comprised of several tabs.
4444
* Whether to use pre-trained weights that are provided; if you choose to do so, the model will be initialized with the specified weights, possibly improving performance (transfer learning).
4545
You can also load custom weights; simply ensure they are compatible with the model.
4646

47+
* The proportion of the dataset to keep for training versus validation; if you have a large dataset, you can set it to a lower value to have more accurate validation steps.
48+
4749
2) The second tab, **Augmentation**, lets you define dataset and augmentation parameters such as :
4850

4951
* Whether to use images "as is" (**requires all images to be of the same size and cubic**) or extract patches.
@@ -53,6 +55,7 @@ The training module is comprised of several tabs.
5355
* The size of patches to be extracted (ideally, please use a value **close to a power of two**, such as 120 or 60 to ensure correct size.)
5456
* The number of samples to extract from each of your images. A larger number will likely mean better performances, but longer training and larger memory usage.
5557

58+
5659
* Whether to perform data augmentation or not (elastic deforms, intensity shifts. random flipping,etc). A rule of thumb for augmentation is :
5760

5861
* If you're using the patch extraction method, enable it if you are using more than 10 samples per image with at least 5 images
@@ -63,6 +66,7 @@ The training module is comprised of several tabs.
6366

6467
* The **model** to use for training (see table above)
6568
* The **loss function** used for training (see table below)
69+
* The **learning rate** of the optimizer. Setting it to a lower value if you're using pre-trained weights can improve performance.
6670
* The **batch size** (larger means quicker training and possibly better performance but increased memory usage)
6771
* The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.)
6872
* The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.)

napari_cellseg3d/plugin_metrics.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from napari_cellseg3d.model_instance_seg import to_semantic
1414
from napari_cellseg3d.plugin_base import BasePluginFolder
1515

16+
DEFAULT_THRESHOLD = 0.5
17+
1618

1719
class MetricsUtils(BasePluginFolder):
1820
"""Plugin to evaluate metrics between two sets of labels, ground truh and prediction"""
@@ -40,8 +42,13 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):
4042
self.btn_compute_dice = ui.make_button(
4143
"Compute Dice", self.compute_dice
4244
)
45+
46+
self.rotate_choice = ui.make_checkbox("Find best orientation")
47+
4348
self.btn_reset_plot = ui.make_button("Clear plots", self.remove_plots)
4449

50+
self.threshold_box = ui.make_n_spinboxes(min=0.1, max = 1, default=DEFAULT_THRESHOLD,step=0.1, double=True)
51+
4552
self.btn_result_path.setVisible(False)
4653
self.lbl_result_path.setVisible(False)
4754

@@ -72,7 +79,7 @@ def build(self):
7279

7380
w, self.layout = ui.make_container_widget()
7481

75-
metrics_group_w, metrics_group_l = ui.make_group("Metrics")
82+
metrics_group_w, metrics_group_l = ui.make_group("Data")
7683

7784
self.lbl_image_files.setText("Ground truth")
7885

@@ -96,11 +103,21 @@ def build(self):
96103
alignment=ui.LEFT_AL,
97104
)
98105

99-
metrics_group_l.addWidget(self.btn_compute_dice, alignment=ui.LEFT_AL)
100-
101106
metrics_group_w.setLayout(metrics_group_l)
107+
############################
108+
ui.add_blank(self, self.layout)
109+
############################
110+
param_group_w, param_group_l = ui.make_group("Parameters")
111+
112+
param_group_l.addWidget(self.threshold_box)
113+
param_group_l.addWidget(self.rotate_choice)
114+
115+
param_group_w.setLayout(param_group_l)
102116

103117
self.layout.addWidget(metrics_group_w)
118+
self.layout.addWidget(param_group_w)
119+
120+
self.layout.addWidget(self.btn_compute_dice, alignment=ui.LEFT_AL)
104121

105122
self.layout.addWidget(self.make_close_button(), alignment=ui.LEFT_AL)
106123

@@ -109,15 +126,15 @@ def build(self):
109126

110127
ui.make_scrollable(self.layout, self)
111128

112-
def plot_dice(self, dice_coeffs):
129+
def plot_dice(self, dice_coeffs, threshold=DEFAULT_THRESHOLD):
113130
"""Plots the dice loss for each pair of labels on viewer"""
114131
self.btn_reset_plot.setVisible(True)
115132
colors = []
116133

117134
bckgrd_color = (0, 0, 0, 0)
118135

119136
for coeff in dice_coeffs: # TODO add threshold manual setting
120-
if coeff < 0.5:
137+
if coeff < threshold:
121138
colors.append(ui.dark_red) # 72071d # crimson red
122139
else:
123140
colors.append(ui.default_cyan) # 8dd3c7 # turquoise cyan
@@ -138,7 +155,7 @@ def plot_dice(self, dice_coeffs):
138155
dice_plot.invert_yaxis()
139156

140157
self.plots.append(self.canvas)
141-
dice_plot.axvline(0.5, color=ui.dark_red)
158+
dice_plot.axvline(threshold, color=ui.dark_red)
142159
dice_plot.set_title(
143160
f"Session {len(self.plots)}\nMean dice : {np.mean(dice_coeffs):.4f}"
144161
)
@@ -157,12 +174,17 @@ def remove_plots(self):
157174
self.btn_reset_plot.setVisible(False)
158175

159176
def compute_dice(self):
160-
"""Computes the dice metric between pairs of labels. Rotates the prediction label to find matching orientation as well."""
177+
"""Computes the dice metric between pairs of labels.
178+
Rotates the prediction label to find matching orientation as well."""
161179
# u = 0
162180
# t = 0
181+
182+
threshold = self.threshold_box.value()
183+
rotate = self.rotate_choice.isChecked()
184+
163185
total_metrics = []
164186
self.canvas = (
165-
None # kind of unsafe way to stack plots... but it works.
187+
None # kind of terrible way to stack plots... but it works.
166188
)
167189
id = 0
168190
for ground_path, pred_path in zip(
@@ -230,18 +252,20 @@ def compute_dice(self):
230252
# )
231253
# u += 1
232254

233-
# TODO add rotation toggle
234-
pred_flip_x = np.rot90(pred[0][0], axes=(0, 1))
235-
pred_flip_y = np.rot90(pred[0][0], axes=(1, 2))
236-
pred_flip_z = np.rot90(pred[0][0], axes=(0, 2))
237255
scores = []
238-
239-
for p in [pred[0][0], pred_flip_x, pred_flip_y, pred_flip_z]:
240-
scores.append(utils.dice_coeff(p, ground))
241-
scores.append(utils.dice_coeff(np.flip(p), ground))
242-
for i in range(3):
243-
scores.append(utils.dice_coeff(np.flip(p, axis=i), ground))
244-
256+
if rotate: # TODO : recored best rotation for display
257+
pred_flip_x = np.rot90(pred[0][0], axes=(0, 1))
258+
pred_flip_y = np.rot90(pred[0][0], axes=(1, 2))
259+
pred_flip_z = np.rot90(pred[0][0], axes=(0, 2))
260+
261+
for p in [pred[0][0], pred_flip_x, pred_flip_y, pred_flip_z]:
262+
scores.append(utils.dice_coeff(p, ground))
263+
scores.append(utils.dice_coeff(np.flip(p), ground))
264+
for i in range(3):
265+
scores.append(utils.dice_coeff(np.flip(p, axis=i), ground))
266+
else:
267+
i=0
268+
scores.append(utils.dice_coeff(pred, ground))
245269
# if t < 1:
246270
# for i in range(3):
247271
# self._viewer.add_image(
@@ -254,7 +278,7 @@ def compute_dice(self):
254278

255279
# print(scores)
256280
score = max(scores)
257-
if score < 0.5:
281+
if score < threshold:
258282
# TODO add filename ?
259283
self._viewer.dims.ndisplay = 3
260284
self._viewer.add_image(
@@ -265,4 +289,4 @@ def compute_dice(self):
265289
)
266290
total_metrics.append(score)
267291
print(f"DICE METRIC :{total_metrics}")
268-
self.plot_dice(total_metrics)
292+
self.plot_dice(total_metrics, threshold)

napari_cellseg3d/plugin_model_training.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import os
22
import shutil
33
import warnings
4-
import zipfile
5-
import pandas as pd
64
from pathlib import Path
75

86
import matplotlib.pyplot as plt
@@ -81,7 +79,6 @@ def __init__(
8179
8280
TODO training plugin:
8381
84-
* Choice of validation proportion
8582
8683
* Custom model loading
8784
@@ -460,6 +457,15 @@ def build(self):
460457
################
461458
ui.add_blank(self, data_tab_layout)
462459
################
460+
461+
validation_group_w, validation_group_l = ui.make_group("Validation %")
462+
463+
validation_group_l.addWidget(self.validation_percent_choice)
464+
validation_group_w.setLayout(validation_group_l)
465+
data_tab_layout.addWidget(validation_group_w)
466+
################
467+
ui.add_blank(self, data_tab_layout)
468+
################
463469
# buttons
464470

465471
data_tab_layout.addWidget(
@@ -516,12 +522,6 @@ def build(self):
516522
#######################
517523
ui.add_blank(augment_tab_w, augment_tab_l)
518524
#######################
519-
validation_group_w, validation_group_l = ui.make_group("Validation %")
520-
521-
validation_group_l.addWidget(self.validation_percent_choice)
522-
validation_group_w.setLayout(validation_group_l)
523-
augment_tab_l.addWidget(validation_group_w)
524-
#######################
525525
augment_group_w, augment_group_l = ui.make_group("Augmentation")
526526
augment_group_l.addWidget(
527527
self.augment_choice, alignment=ui.LEFT_AL

0 commit comments

Comments
 (0)