Skip to content

Commit c578bd9

Browse files
authored
WNet training (#47)
* Add LayerNorm * Change softmax arg * Num group 2 * Update model.py * Update model.py * Reduce depth of WNet * Started WNet training UI * Workable WNet training prototype * Fixes * Test fixes * Temp fix for CRF (#46) * Minor fixes * Tests & training * Fix tests + new weights * Fix ETA precision * Docstring update * Update plugin_model_training.py * Update contrast limit when updating layers * Update config.py * Fixed normalization * Update plugin_model_training.py * Update workers_utils.py * Trying to fix input normalization * Fix name mismatch * Fix decoder evaluation * Update dice calculation * Update dice coeff * Update worker_training.py * Fix eval detach * Fix Dice list for WNet * Updated validation UI * Tooltips and show_results update * Plots update * Plot + log_parameters * Update worker_training.py * Disable WANDB for now + log param tweaks * UI/log tweaks * Functional WNet training * Clean exit / free memory attempt * Cleanup + tests - Removed previous train script - Fix tests - Enable test workflow on GH * Deploy memory usage fix in inference as well * Memory usage fix * UI tweak * WNet cleanup + supervised training improvements * Change Dice metric include_background for WNet To avoid Max Dice calculation * Set better default LR across un/supervised * Update model.py * Update WNet weights * Fix default LR + sup. test * Fix new unsup LR in tests * Fix dir for saving in tests * Testing fixes Due to Singleton Trainer widget * Test unsupervised training and raise coverage * WNet eval test * Fix order for model deletion * Extend supervised train tests * Started docs update * Update plugin_model_training.py * Fixed filepaths * Fix paths in test (use pathlib) * Updated workers config * Fixed parse_default_path test * Ignore wandb results in gitignore * Enable GH Actions tests on branch temporarily * Fixed deletion of Qt imports in interface * Reverted include_background=True in Dice * Reintroduced best Dice channel seeking + refacto * Improve filepath messages * Fix unsup image loading when not validating * Fix training tests
1 parent 7a4e31f commit c578bd9

32 files changed

+2644
-1759
lines changed

.github/workflows/test_and_deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77
push:
88
branches:
99
- main
10-
- cy/utils
10+
- cy/wnet-train
1111
tags:
1212
- "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
1313
pull_request:

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ venv/
103103
/docs/res/logo/old_logo/
104104
/reqs/
105105
/loss_plots/
106+
/wandb/
106107
notebooks/csv_cell_plot.html
107108
notebooks/full_plot.html
108109
*.csv

docs/res/code/model_framework.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Class : ModelFramework
1212
Methods
1313
**********************
1414
.. autoclass:: napari_cellseg3d.code_models.model_framework::ModelFramework
15-
:members: __init__, send_log, save_log, save_log_to_path, display_status_report, create_train_dataset_dict, get_model, get_available_models, get_device, empty_cuda_cache
15+
:members: __init__, send_log, save_log, save_log_to_path, display_status_report, create_train_dataset_dict, get_available_models, get_device, empty_cuda_cache
1616
:noindex:
1717

1818

docs/res/code/plugin_model_training.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ Class : Trainer
1111
Methods
1212
**********************
1313
.. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer
14-
:members: __init__, get_loss, check_ready, send_log, start, on_start, on_finish, on_error, on_yield, plot_loss, update_loss_plot
14+
:members: __init__, check_ready, send_log, start, on_start, on_finish, on_error, on_yield, update_loss_plot
1515
:noindex:
1616

1717

1818

1919
Attributes
2020
*********************
2121
.. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer
22-
:members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot
22+
:members: _viewer, worker, canvas

docs/res/code/workers.rst

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Class : LogSignal
1010

1111
Attributes
1212
************************
13-
.. autoclass:: napari_cellseg3d.code_models.workers::LogSignal
13+
.. autoclass:: napari_cellseg3d.code_models.workers_utils::LogSignal
1414
:members: log_signal
1515
:noindex:
1616

@@ -24,21 +24,47 @@ Class : InferenceWorker
2424

2525
Methods
2626
************************
27-
.. autoclass:: napari_cellseg3d.code_models.workers::InferenceWorker
27+
.. autoclass:: napari_cellseg3d.code_models.worker_inference::InferenceWorker
2828
:members: __init__, log, create_inference_dict, inference
2929
:noindex:
3030

3131
.. _here: https://napari-staging-site.github.io/guides/stable/threading.html
3232

3333

34-
Class : TrainingWorker
34+
Class : TrainingWorkerBase
3535
-------------------------------------------
3636

3737
.. important::
3838
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`
3939

4040
Methods
4141
************************
42-
.. autoclass:: napari_cellseg3d.code_models.workers::TrainingWorker
42+
.. autoclass:: napari_cellseg3d.code_models.worker_training::TrainingWorkerBase
4343
:members: __init__, log, train
4444
:noindex:
45+
46+
47+
Class : WNetTrainingWorker
48+
-------------------------------------------
49+
50+
.. important::
51+
Inherits from :py:class:`TrainingWorkerBase`
52+
53+
Methods
54+
************************
55+
.. autoclass:: napari_cellseg3d.code_models.worker_training::WNetTrainingWorker
56+
:members: __init__, train, eval, get_patch_dataset, get_dataset_eval, get_dataset
57+
:noindex:
58+
59+
60+
Class : SupervisedTrainingWorker
61+
-------------------------------------------
62+
63+
.. important::
64+
Inherits from :py:class:`TrainingWorkerBase`
65+
66+
Methods
67+
************************
68+
.. autoclass:: napari_cellseg3d.code_models.worker_training::SupervisedTrainingWorker
69+
:members: __init__, train
70+
:noindex:

docs/res/guides/detailed_walkthrough.rst

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ Finally, the last tab lets you choose :
120120

121121
* SegResNet is a lightweight model (low memory requirements) from MONAI originally designed for 3D fMRI data.
122122
* VNet is a larger (than SegResNet) CNN from MONAI designed for medical image segmentation.
123-
* TRAILMAP is our PyTorch implementation of a 3D CNN model trained for axonal detection in cleared tissue.
124123
* TRAILMAP_MS is our implementation in PyTorch additionally trained on mouse cortical neural nuclei from mesoSPIM data.
125-
* Note, the code is very modular, so it is relatively straightforward to use (and contribute) your model as well.
124+
* SwinUNetR is a MONAI implementation of the SwinUNetR model. It is costly in compute and memory, but can achieve high performance.
125+
* WNet is our reimplementation of an unsupervised model, which can be used to produce segmentation without labels.
126126

127127

128128
* The loss : for object detection in 3D volumes you'll likely want to use the Dice or Dice-focal Loss.
@@ -239,13 +239,12 @@ Scoring, review, analysis
239239
----------------------------
240240

241241

242-
.. Using the metrics utility module, you can compare the model's predictions to any ground truth
243-
labels you might have.
244-
Simply provide your prediction and ground truth labels, and compute the results.
245-
A Dice metric of 1 indicates perfect matching, whereas a score of 0 indicates complete mismatch.
246-
Select which score **you consider as sub-optimal**, and all results below this will be **shown in napari**.
247-
If at any time the **orientation of your prediction labels changed compared to the ground truth**, check the
248-
"Find best orientation" option to compensate for it.
242+
.. Using the metrics utility module, you can compare the model's predictions to any ground truth labels you might have.
243+
Simply provide your prediction and ground truth labels, and compute the results.
244+
A Dice metric of 1 indicates perfect matching, whereas a score of 0 indicates complete mismatch.
245+
Select which score **you consider as sub-optimal**, and all results below this will be **shown in napari**.
246+
If at any time the **orientation of your prediction labels changed compared to the ground truth**, check the
247+
"Find best orientation" option to compensate for it.
249248
250249
251250
Labels review

docs/res/guides/training_module_guide.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Training module guide - Unsupervised models
44
==============================================
55

66
.. important::
7-
The WNet training is for now only available in the provided jupyter notebook, in the ``notebooks`` folder.
7+
The WNet training is for now available as part of the plugin in the Training module.
88
Please see the :ref:`training_wnet` section for more information.
99

1010
Training module guide - Supervised models
@@ -25,14 +25,15 @@ Model Link to original paper
2525
============== ================================================================================================
2626
VNet `Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation`_
2727
SegResNet `3D MRI brain tumor segmentation using autoencoder regularization`_
28-
TRAILMAP_MS A PyTorch implementation of the `TRAILMAP project on GitHub`_ pretrained with MesoSpim data
29-
TRAILMAP An implementation of the `TRAILMAP project on GitHub`_ using a `3DUNet for PyTorch`_
28+
TRAILMAP_MS An implementation of the `TRAILMAP project on GitHub`_ using `3DUNet for PyTorch`_
29+
SwinUNetR `Swin UNETR, Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images`_
3030
============== ================================================================================================
3131

3232
.. _Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation: https://arxiv.org/pdf/1606.04797.pdf
3333
.. _3D MRI brain tumor segmentation using autoencoder regularization: https://arxiv.org/pdf/1810.11654.pdf
3434
.. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP
3535
.. _3DUnet for Pytorch: https://github.com/wolny/pytorch-3dunet
36+
.. _Swin UNETR, Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images: https://arxiv.org/abs/2201.01266
3637

3738
.. important::
3839
| The machine learning models used by this program require all images of a dataset to be of the same size.

docs/res/guides/training_wnet.rst

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,45 @@ the model was trained on; you can retrain from our pretrained model to your set
1515
The model has two losses, the SoftNCut loss which clusters pixels according to brightness, and a reconstruction loss, either
1616
Mean Square Error (MSE) or Binary Cross Entropy (BCE).
1717
Unlike the original paper, these losses are added in a weighted sum and the backward pass is performed for the whole model at once.
18-
The SoftNcuts is bounded between 0 and 1; the MSE may take large values.
18+
The SoftNcuts is bounded between 0 and 1; the MSE may take large positive values.
1919

20-
For good performance, one should wait for the SoftNCut to reach a plateau, the reconstruction loss must also diminish but it's generally less critical.
20+
For good performance, one should wait for the SoftNCut to reach a plateau; the reconstruction loss must also diminish but it's generally less critical.
2121

22+
Parameters
23+
-------------------------------
24+
25+
When using the WNet training module, additional options will be provided in the Advanced tab of the training module:
26+
27+
- Number of classes : number of classes to segment (default 2). Additional classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects or artifacts with a significantly different brightness.
28+
- Reconstruction loss : either MSE or BCE (default MSE). MSE is more sensitive to outliers, but can be more precise; BCE is more robust to outliers but can be less precise.
29+
30+
- NCuts parameters:
31+
- Intensity sigma : standard deviation of the feature similarity term (brightness here, default 1)
32+
- Spatial sigma : standard deviation of the spatial proximity term (default 4)
33+
- Radius : radius of the loss computation in pixels (default 2)
34+
35+
.. note::
36+
Intensity sigma depends on pixel values in the image. The default of 1 is tailored to images being mapped between 0 and 100, which is done automatically by the plugin.
37+
.. note::
38+
Raising the radius might improve performance in some cases, but will also greatly increase computation time.
39+
40+
- Weights for the sum of losses :
41+
- NCuts weight : weight of the NCuts loss (default 0.5)
42+
- Reconstruction weight : weight of the reconstruction loss (default 0.5*1e-2)
43+
44+
.. note::
45+
The weight of the reconstruction loss should be adjusted according to its empirical value; ideally the reconstruction loss should be of the same order of magnitude as the NCuts loss after being multiplied by its weight.
2246

2347
Common issues troubleshooting
2448
------------------------------
25-
If you do not find a satisfactory answer here, please `open an issue`_ !
49+
If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.
2650

27-
- **The NCuts loss explodes after a few epochs** : Lower the learning rate
51+
- **The NCuts loss explodes after a few epochs** : Lower the learning rate, first by a factor of two, then ten.
2852

2953
- **The NCuts loss does not converge and is unstable** :
30-
The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image; for reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1.
54+
The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image. For reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1.
3155

32-
- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss ot 0.5, OR adjust the weight of the MSE loss to make it closer to 1.
56+
- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss ot 0.5, OR adjust the weight of the MSE loss to make it closer to 1 in the weighted sum.
3357

3458

3559
.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506

napari_cellseg3d/_tests/fixtures.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from qtpy.QtWidgets import QTextEdit
23

34
from napari_cellseg3d.utils import LOGGER as logger
@@ -17,3 +18,60 @@ def warn(self, warning):
1718

1819
def error(self, e):
1920
raise (e)
21+
22+
23+
class WNetFixture(torch.nn.Module):
24+
def __init__(self):
25+
super().__init__()
26+
self.mock_conv = torch.nn.Conv3d(1, 1, 1)
27+
self.mock_conv.requires_grad_(False)
28+
29+
def forward_encoder(self, x):
30+
return x
31+
32+
def forward_decoder(self, x):
33+
return x
34+
35+
def forward(self, x):
36+
return self.forward_encoder(x), self.forward_decoder(x)
37+
38+
39+
class ModelFixture(torch.nn.Module):
40+
def __init__(self):
41+
super().__init__()
42+
self.mock_conv = torch.nn.Conv3d(1, 1, 1)
43+
self.mock_conv.requires_grad_(False)
44+
45+
def forward(self, x):
46+
return x
47+
48+
49+
class OptimizerFixture:
50+
def __init__(self):
51+
self.param_groups = []
52+
self.param_groups.append({"lr": 0})
53+
54+
def zero_grad(self):
55+
pass
56+
57+
def step(self, *args):
58+
pass
59+
60+
61+
class SchedulerFixture:
62+
def step(self, *args):
63+
pass
64+
65+
66+
class LossFixture:
67+
def __call__(self, *args):
68+
return self
69+
70+
def backward(self, *args):
71+
pass
72+
73+
def item(self):
74+
return 0
75+
76+
def detach(self):
77+
return self

napari_cellseg3d/_tests/test_inference.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_onnx_inference(make_napari_viewer_proxy):
2323
path = str(Path(PRETRAINED_WEIGHTS_DIR).resolve() / "wnet.onnx")
2424
assert Path(path).is_file()
2525
dims = 64
26-
batch = 2
26+
batch = 1
2727
x = torch.randn(size=(batch, 1, dims, dims, dims))
2828
worker = ONNXModelWrapper(file_location=path)
2929
assert worker.eval() is None
@@ -66,19 +66,23 @@ def test_inference_on_folder():
6666
config.images_filepaths = [
6767
str(Path(__file__).resolve().parent / "res/test.tif")
6868
]
69-
config.sliding_window_config.window_size = 64
69+
70+
config.sliding_window_config.window_size = 8
7071

7172
class mock_work:
72-
def __call__(self, x):
73-
return x
73+
@staticmethod
74+
def eval():
75+
return True
7476

75-
def eval(self):
76-
return None
77+
def __call__(self, x):
78+
return torch.Tensor(x)
7779

7880
worker = InferenceWorker(worker_config=config)
7981
worker.aniso_transform = mock_work()
8082

81-
image = torch.Tensor(rand_gen.random((1, 1, 64, 64, 64)))
83+
image = torch.Tensor(rand_gen.random(size=(1, 1, 8, 8, 8)))
84+
assert image.shape == (1, 1, 8, 8, 8)
85+
assert image.dtype == torch.float32
8286
res = worker.inference_on_folder(
8387
{"image": image},
8488
0,

0 commit comments

Comments
 (0)