Commit 480d714

Update the SynapseNet training CLI
1 parent be0917a commit 480d714

File tree

3 files changed: +81, -24 lines changed

doc/start_page.md

Lines changed: 23 additions & 6 deletions
@@ -147,10 +147,12 @@ For more options supported by the IMOD exports, please run `synapse_net.export_t
 
 > Note: to use these commands you have to install IMOD.
 
+SynapseNet also provides two CLI commands for training models, one for supervised network training (see [Supervised Training](#supervised-training) for details) and one for domain adaptation (see [Domain Adaptation](#domain-adaptation) for details).
+
 
 ## Python Library
 
-Using the `synapse_net` python library offers the most flexibility for using the SynapseNet functionality.
+Using the `synapse_net` python library offers the most flexibility for using SynapseNet's functionality.
 You can find an example analysis pipeline implemented with SynapseNet [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/analysis_pipeline.py).
 
 We offer different functionality for segmenting and analyzing synapses in electron microscopy:
@@ -161,17 +163,32 @@ We offer different functionality for segmenting and analyzing synapses in electr
 
 Please refer to the module documentation below for a full overview of our library's functionality.
 
+### Supervised Training
+
+SynapseNet provides functionality for training a UNet for segmentation tasks using supervised learning.
+In this case, you have to provide data **and** (manual) annotations for the structure(s) you want to segment.
+This functionality is implemented in `synapse_net.training.supervised_training`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/network_training.py).
+
+We also provide a command line function to run supervised training: `synapse_net.run_supervised_training`. Run
+```bash
+synapse_net.run_supervised_training -h
+```
+for more information and instructions on how to use it.
+
 ### Domain Adaptation
 
-We provide functionality for domain adaptation. It implements a special form of neural network training that can improve segmentation for data from a different condition (e.g. different sample preparation, electron microscopy technique or different specimen), **without requiring additional annotated structures**.
+SynapseNet provides functionality for (unsupervised) domain adaptation.
+This functionality is implemented through a student-teacher training approach that can improve segmentation for data from a different condition (for example different sample preparation, imaging technique, or different specimen), **without requiring additional annotated structures**.
 Domain adaptation is implemented in `synapse_net.training.domain_adaptation`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/domain_adaptation.py).
 
-> Note: Domain adaptation only works if the initial model you adapt already finds some of the structures in the data from a new condition. If it does not work you will have to train a network on annotated data.
+We also provide a command line function to run domain adaptation: `synapse_net.run_domain_adaptation`. Run
+```bash
+synapse_net.run_domain_adaptation -h
+```
+for more information and instructions on how to use it.
 
-### Network Training
+> Note: Domain adaptation only works if the initial model already finds some of the structures in the data from a new condition. If it does not work you will have to train a network on annotated data.
 
-We also provide functionality for 'regular' neural network training. In this case, you have to provide data **and** manual annotations for the structure(s) you want to segment.
-This functionality is implemented in `synapse_net.training.supervised_training`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/network_training.py).
 
 ## Segmentation for the CryoET Data Portal

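The library route mentioned in the updated documentation can be sketched as follows. This is a minimal, hypothetical example rather than code from this commit: the paths, model name, and patch shape are placeholders, and the keyword arguments mirror those the updated CLI below passes to `mean_teacher_adaptation`.

```python
# Hypothetical sketch of adapting the vesicle model to new tomograms via the Python library.
# All paths, the model name, and the patch shape are placeholders, not values from this commit.
from synapse_net.training.domain_adaptation import mean_teacher_adaptation

mean_teacher_adaptation(
    name="adapted_model",  # checkpoints are written to checkpoints/adapted_model
    unsupervised_train_paths=["/path/to/tomo_a.mrc", "/path/to/tomo_b.mrc"],
    unsupervised_val_paths=["/path/to/tomo_val.mrc"],
    patch_shape=(48, 256, 256),  # placeholder; the CLI derives this from the source model by default
    source_checkpoint="/path/to/vesicles_3d_checkpoint",  # placeholder path to the pretrained source model
    raw_key=None,  # placeholder; the CLI derives the internal dataset key from the file extension
    n_iterations=int(1e4),
    batch_size=1,
)
```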
synapse_net/training/domain_adaptation.py

Lines changed: 45 additions & 14 deletions
@@ -12,7 +12,9 @@
 from sklearn.model_selection import train_test_split
 
 from .semisupervised_training import get_unsupervised_loader
-from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
+from .supervised_training import (
+    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
+)
 from ..inference.inference import get_model_path, compute_scale_from_voxel_size
 from ..inference.util import _Scaler
 
@@ -166,13 +168,11 @@ def mean_teacher_adaptation(
     """
 
 
-def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir):
+def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
     files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
     if len(files) == 0:
         raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")
 
-    val_fraction = 0.15
-
     # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
     # And resave the volumes.
     resave_val_crops = len(files) < 4
@@ -235,30 +235,61 @@ def main():
     import argparse
 
     parser = argparse.ArgumentParser(
-        description=""
+        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
+        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
+        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
+        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). "  # noqa
+        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
+        "Check out the information below for details on the arguments of this function."
+    )
+    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
+    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
+    parser.add_argument("--file_pattern", default="*",
+                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
+    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
+    parser.add_argument(
+        "--source_model",
+        default="vesicles_3d",
+        help="The source model used for weight initialization of teacher and student model. "
+        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
+    )
+    parser.add_argument(
+        "--resize_training_data", action="store_true",
+        help="Whether to resize the training data to fit the voxel size of the source model's training data."
     )
-    parser.add_argument("--name", "-n", required=True)
-    parser.add_argument("--input", "-i", required=True)
-    parser.add_argument("--pattern", "-p", default="*.mrc")
-    parser.add_argument("--source_model", default="vesicles_3d")
-    parser.add_argument("--resize_training_data", action="store_true")
-    parser.add_argument("--n_iterations", type=int, default=int(1e4))
-    parser.add_argument("--patch_shape", nargs="+", type=int)
+    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
+    parser.add_argument(
+        "--patch_shape", nargs=3, type=int,
+        help="The patch shape for training. By default the patch shape the source model was trained with is used."
+    )
+
+    # More optional arguments:
+    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
+    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
+    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
+    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
+    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
+
     args = parser.parse_args()
 
     source_checkpoint = get_model_path(args.source_model)
     patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
     with tempfile.TemporaryDirectory() as tmp_dir:
         unsupervised_train_paths, unsupervised_val_paths = _get_paths(
-            args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir
+            args.input_folder, args.file_pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction,  # noqa
         )
+        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)
 
         mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
-            raw_key="data",
+            raw_key=raw_key,
            n_iterations=args.n_iterations,
+            batch_size=args.batch_size,
+            n_samples_train=args.n_samples_train,
+            n_samples_val=args.n_samples_val,
+            check=args.check,
        )

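The new `--key` handling above relies on `_derive_key_from_files`, which this commit factors out of `_parse_input_folder` in the next file. A rough, hypothetical sketch of that step follows; the folder and pattern are placeholders, and the helper is private and shown only for illustration.

```python
# Hypothetical sketch of the key-derivation step the updated CLI performs before training.
import os
from glob import glob

from synapse_net.training.supervised_training import _derive_key_from_files

# Collect input volumes with a placeholder folder and pattern.
files = sorted(glob(os.path.join("/path/to/tomograms", "**", "*.mrc"), recursive=True))

# With key=None the helper tries to derive the internal file path from the file extension,
# mirroring the CLI's default; an explicit --key value would be passed through unchanged.
files, raw_key = _derive_key_from_files(files, None)

# raw_key then replaces the previously hard-coded raw_key="data" in the call to mean_teacher_adaptation.
```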
synapse_net/training/supervised_training.py

Lines changed: 13 additions & 4 deletions
@@ -306,8 +306,7 @@ def supervised_training(
     trainer.fit(n_iterations)
 
 
-def _parse_input_folder(folder, pattern, key):
-    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
+def _derive_key_from_files(files, key):
     # Get all file extensions (general wild-cards may pick up files with multiple extensions).
     extensions = list(set([os.path.splitext(ff)[1] for ff in files]))
 
@@ -325,7 +324,7 @@ def _parse_input_folder(folder, pattern, key):
     # If the key is None and can't be derived raise an error.
     elif key is None and ext not in extension_to_key:
         raise ValueError(
-            f"You have not passed a key for the data in {folder}, but the key could not be derived for{ext} format."
+            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
         )
     # If the key was passed and doesn't match the extension raise an error.
     elif key is not None and ext in extension_to_key and key != extension_to_key[ext]:
@@ -335,6 +334,11 @@ def _parse_input_folder(folder, pattern, key):
     return files, key
 
 
+def _parse_input_folder(folder, pattern, key):
+    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
+    return _derive_key_from_files(files, key)
+
+
 def _parse_input_files(args):
     train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
     train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
@@ -366,7 +370,12 @@ def main():
     import argparse
 
     parser = argparse.ArgumentParser(
-        description="Train a model for foreground and boundary segmentation via supervised learning."
+        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
+        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
+        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
+        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument). "  # noqa
+        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
+        "Check out the information below for details on the arguments of this function."
     )
     parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
     parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.")

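For completeness, a small hypothetical sketch of how the two refactored helpers relate. The folder, pattern, and key values are placeholders, both helpers are private, and this is only an illustration of the intended usage, not code from this commit.

```python
# Hypothetical sketch of the refactor: _parse_input_folder still globs a folder and then
# derives the dataset key, but the derivation step now lives in _derive_key_from_files.
from synapse_net.training.supervised_training import _derive_key_from_files, _parse_input_folder

# Supervised CLI path: unchanged behavior, glob a (placeholder) folder and resolve the key in one call.
# Passing key=None lets the helper derive the key from the file extension where possible;
# otherwise an explicit key has to be provided.
image_paths, raw_key = _parse_input_folder("/path/to/images", "*.h5", None)

# New reusable step: resolve the key for a file list that was collected elsewhere,
# which is what the domain adaptation CLI now does.
image_paths, raw_key = _derive_key_from_files(image_paths, None)
```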