
Commit 0c16964

WIP added yield mid-training + docs + fixes and cleanup
1 parent 733ad9d commit 0c16964

6 files changed (+46, -32 lines)

docs/res/guides/training_module_guide.rst

Lines changed: 21 additions & 15 deletions
@@ -23,9 +23,12 @@ TRAILMAP An emulation of the `TRAILMAP project on GitHub`_ using `3DUne
 .. _3DUnet for Pytorch: https://github.com/wolny/pytorch-3dunet
 
 .. important::
-   The machine learning models used by this program require all images of a dataset to be of the same size.
-   Please ensure that all the images you are loading are of the **same size**, or to use the **"extract patches" (in augmentation tab)** with an appropriately small size
-   to ensure all images being used by the model are of a workable size.
+   | The machine learning models used by this program require all images of a dataset to be of the same size.
+   | Please ensure that all the images you are loading are of the **same size**, or use the **"extract patches" (in augmentation tab)** option with an appropriately small size to ensure all images being used by the model are of a workable size.
+
+.. important::
+   | **All image sizes used should be as close to a power of two as possible, if not exactly a power of two.**
+   | Images are automatically padded; a 64-pixel cube will be used as is, but a 65-pixel cube will be padded up to 128 pixels, resulting in much higher memory use.
 
 The training module is comprised of several tabs.
 
@@ -41,36 +44,39 @@ The training module is comprised of several tabs.
 * Whether to use images "as is" (**requires all images to be of the same size and cubic**) or extract patches.
 
 * If you're extracting patches :
-  * The size of patches to be extracted (ideally, please use a value **close to a power of two**, such as 120 or 60.
-  * The number of samples to extract from each of your image to ensure correct size and perform data augmentation. A larger number will likely mean better performances, but longer training and larger memory usage.
+
+  * The size of patches to be extracted (ideally, please use a value **close to a power of two**, such as 120 or 60, to ensure a correct size).
+  * The number of samples to extract from each of your images. A larger number will likely mean better performance, but longer training and larger memory usage.
+
 * Whether to perform data augmentation or not (elastic deforms, intensity shifts. random flipping,etc). A rule of thumb for augmentation is :
+
   * If you're using the patch extraction method, enable it if you are using more than 10 samples per image with at least 5 images
   * If you have a large dataset and are not using patches extraction, enable it.
 
 
 3) The third contains training related parameters :
 
-* The model to use for training (see table above)
-* The loss function used for training (see table below)
-* The batch size (larger means quicker training and possibly better performance but increased memory usage)
-* The number of epochs (a possibility is to start with 60 epochs, and decrease or increase depending on performance.)
-* The epoch interval for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.)
+* The **model** to use for training (see table above)
+* The **loss function** used for training (see table below)
+* The **batch size** (larger means quicker training and possibly better performance but increased memory usage)
+* The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.)
+* The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.)
 
-If the dice metric is better on that validation interval, the model weights will be saved in the results folder.
+.. note::
+   If the dice metric is better on a given validation interval, the model weights will be saved in the results folder.
 
 The available loss functions are :
 
-======================== ====================================================
+======================== ================================================================================================
 Function                 Reference
-======================== ====================================================
+======================== ================================================================================================
 Dice loss                `Dice Loss from MONAI`_ with ``sigmoid=true``
 Focal loss               `Focal Loss from MONAI`_
 Dice-Focal loss          `Dice-focal Loss from MONAI`_ with ``sigmoid=true`` and ``lambda_dice = 0.5``
 Generalized Dice loss    `Generalized dice Loss from MONAI`_ with ``sigmoid=true``
 Dice-CE loss             `Dice-CE Loss from MONAI`_ with ``sigmoid=true``
 Tversky loss             `Tversky Loss from MONAI`_ with ``sigmoid=true``
-======================== ====================================================
-
+======================== ================================================================================================
 .. _Dice Loss from MONAI: https://docs.monai.io/en/stable/losses.html#diceloss
 .. _Focal Loss from MONAI: https://docs.monai.io/en/stable/losses.html#focalloss
 .. _Dice-focal Loss from MONAI: https://docs.monai.io/en/stable/losses.html#dicefocalloss
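
For reference, the loss functions in the table above map onto MONAI loss classes. A minimal sketch of how they could be instantiated with the parameters listed in the table (illustrative only; this is not the plugin's construction code, and anything not listed is left at MONAI defaults):

    # Sketch: the losses from the table above, instantiated via MONAI.
    # Parameter choices mirror the table; everything else uses MONAI defaults.
    from monai.losses import (
        DiceCELoss,
        DiceFocalLoss,
        DiceLoss,
        FocalLoss,
        GeneralizedDiceLoss,
        TverskyLoss,
    )

    loss_functions = {
        "Dice loss": DiceLoss(sigmoid=True),
        "Focal loss": FocalLoss(),
        "Dice-Focal loss": DiceFocalLoss(sigmoid=True, lambda_dice=0.5),
        "Generalized Dice loss": GeneralizedDiceLoss(sigmoid=True),
        "Dice-CE loss": DiceCELoss(sigmoid=True),
        "Tversky loss": TverskyLoss(sigmoid=True),
    }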

src/napari_cellseg3d/launch_review.py

Lines changed: 11 additions & 11 deletions
@@ -67,19 +67,19 @@ def launch_review(
 
 
     """
-    global slicer  # Todo : is this okay ? ask Max
-    global z_pos
-    global view1
-    global layer
-    global images_original
-    global base_label
+    # global slicer # Todo : is this okay ? ask Max. seems to work without, keep an eye on it
+    # global z_pos
+    # global view1
+    # global layer
+    # global images_original
+    # global base_label
     images_original = original
     base_label = base
-    try:
-        del view1
-        del layer
-    except NameError:
-        pass
+    # try:
+    #     del view1
+    #     del layer
+    # except NameError:
+    #     pass
 
     view1 = viewer
     view1.add_image(
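
The commented-out globals suggest the review state no longer needs to live at module level. One hypothetical way to make that explicit (not part of this commit; all names below are illustrative) is to keep the viewer and its layers in a small session object:

    # Hypothetical sketch only: holding review state in an object instead of
    # module-level globals. ReviewSession and the layer names are illustrative.
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class ReviewSession:
        viewer: Any       # the napari viewer the review runs in
        image_layer: Any  # layer holding the original images
        label_layer: Any  # layer holding the labels under review

    def launch_review_sketch(viewer, original, base) -> ReviewSession:
        image_layer = viewer.add_image(original, name="original")
        label_layer = viewer.add_labels(base, name="labels")
        return ReviewSession(viewer, image_layer, label_layer)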

src/napari_cellseg3d/model_workers.py

Lines changed: 2 additions & 0 deletions
@@ -757,6 +757,7 @@ def train(self):
                     f"* {step}/{len(train_ds) // train_loader.batch_size}, "
                     f"Train loss: {loss.detach().item():.4f}"
                 )
+                yield {"plot": False, "weights": model.state_dict()}
 
             epoch_loss /= step
             epoch_loss_values.append(epoch_loss)
@@ -804,6 +805,7 @@ def train(self):
                 val_metric_values.append(metric)
 
                 train_report = {
+                    "plot": True,
                     "epoch": epoch,
                     "losses": epoch_loss_values,
                     "val_metrics": val_metric_values,

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 5 additions & 4 deletions
@@ -873,10 +873,11 @@ def on_yield(data, widget):
         # print(
         #     f"\nCatching results : for epoch {data['epoch']}, loss is {data['losses']} and validation is {data['val_metrics']}"
         # )
-        widget.progress.setValue(
-            100 * (data["epoch"] + 1) // widget.max_epochs
-        )
-        widget.update_loss_plot(data["losses"], data["val_metrics"])
+        if data["plot"]:
+            widget.progress.setValue(
+                100 * (data["epoch"] + 1) // widget.max_epochs
+            )
+            widget.update_loss_plot(data["losses"], data["val_metrics"])
 
         if widget.stop_requested:
             torch.save(
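
on_yield now only advances the progress bar and refreshes the plot when the worker asks for it; the plot=False yields just carry weights. The hunk does not show how the worker is connected, but with napari's threading helpers the wiring typically looks roughly like this (a sketch; every name other than thread_worker is illustrative):

    # Sketch of wiring a yielding training worker to a GUI callback with
    # napari's threading helpers. Only thread_worker is a real napari API here.
    from napari.qt.threading import thread_worker

    @thread_worker
    def train_sketch():
        for epoch in range(3):
            yield {"plot": False, "weights": None}  # per-step yield
            yield {                                  # validation-time yield
                "plot": True,
                "epoch": epoch,
                "losses": [0.5, 0.4, 0.3][: epoch + 1],
                "val_metrics": [0.8],
            }

    def on_yield_sketch(data, widget):
        if data["plot"]:
            print(f"epoch {data['epoch']}: update plot with {data['losses']}")

    def start_training_sketch(widget=None):
        worker = train_sketch()  # calling the decorated generator returns a worker
        worker.yielded.connect(lambda data: on_yield_sketch(data, widget))
        worker.start()           # runs train_sketch() in a background thread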

src/napari_cellseg3d/plugin_review.py

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ def run_review(self):
             warnings.warn(
                 "Opening several loader sessions in one window is not supported; opening in new window"
             )
-            self._viewer.remove_from_viewer()
+            self._viewer.close()
         else:
             viewer = self._viewer
             print("new sess")

src/napari_cellseg3d/utils.py

Lines changed: 6 additions & 1 deletion
@@ -149,9 +149,14 @@ def get_padding_dim(image_shape, anisotropy_factor=None):
             # problems with zero divs avoided via params for spinboxes
             size = int(size / anisotropy_factor[i])
         while pad < size:
+
+            if size - pad < 30:
+                warnings.warn(f"Your value is close to a lower power of two; you might want to choose slightly smaller"
+                              f" sizes and/or crop your images down to {pad}")
+
             pad = 2**n
             n += 1
-        if pad >= 1024:
+        if pad >= 256:
             warnings.warn(
                 "Warning : a very large dimension for automatic padding has been computed.\n"
                 "Ensure your images are of an appropriate size and/or that you have enough memory."
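
For context, the padding helper rounds each dimension up to the next power of two, and this commit adds a warning when a size sits just above a power of two, while lowering the "very large padding" threshold from 1024 to 256. A self-contained sketch of that logic, using the same thresholds as the hunk above (the real get_padding_dim also handles the anisotropy_factor argument and other details):

    # Self-contained sketch of the next-power-of-two padding logic shown above.
    # Thresholds (30 px proximity, 256 px "large" padding) mirror this commit.
    import warnings

    def get_padding_dim_sketch(image_shape):
        padding = []
        for size in image_shape:
            pad, n = 1, 1
            while pad < size:
                if size - pad < 30:
                    warnings.warn(
                        f"Your value is close to a lower power of two; you might want "
                        f"to choose slightly smaller sizes and/or crop your images down to {pad}"
                    )
                pad = 2**n
                n += 1
            if pad >= 256:
                warnings.warn(
                    "A very large dimension for automatic padding has been computed.\n"
                    "Ensure your images are of an appropriate size and/or that you have enough memory."
                )
            padding.append(pad)
        return padding

    print(get_padding_dim_sketch((64, 65, 200)))  # -> [64, 128, 256], with warnings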