Skip to content

Commit 70628f6

Browse files
Implement CLI for supervised training
1 parent 0a8101e commit 70628f6

File tree

2 files changed

+123
-3
lines changed

2 files changed

+123
-3
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"synapse_net.run_segmentation = synapse_net.tools.cli:segmentation_cli",
1717
"synapse_net.export_to_imod_points = synapse_net.tools.cli:imod_point_cli",
1818
"synapse_net.export_to_imod_objects = synapse_net.tools.cli:imod_object_cli",
19+
"synapse_net.run_supervised_training = synapse_net.training.supervised_training:main",
1920
],
2021
"napari.manifest": [
2122
"synapse_net = synapse_net:napari.yaml",

synapse_net/training/supervised_training.py

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import os
2+
from glob import glob
13
from typing import Optional, Tuple
24

35
import torch
46
import torch_em
7+
from sklearn.model_selection import train_test_split
58
from torch_em.model import AnisotropicUNet, UNet2d
69

710

@@ -95,6 +98,7 @@ def get_supervised_loader(
9598
sampler: Optional[callable] = None,
9699
ignore_label: Optional[int] = None,
97100
label_transform: Optional[callable] = None,
101+
label_paths: Optional[Tuple[str]] = None,
98102
**loader_kwargs,
99103
) -> torch.utils.data.DataLoader:
100104
"""Get a dataloader for supervised segmentation training.
@@ -118,6 +122,8 @@ def get_supervised_loader(
118122
ignored in the loss computation. By default this option is not used.
119123
label_transform: Label transform that is applied to the segmentation to compute the targets.
120124
If no label transform is passed (the default) a boundary transform is used.
125+
label_paths: Optional paths containing the labels / annotations for training.
126+
If not given, the labels are expected to be contained in the `data_paths`.
121127
loader_kwargs: Additional keyword arguments for the dataloader.
122128
123129
Returns:
@@ -155,9 +161,14 @@ def get_supervised_loader(
155161
if sampler is None:
156162
sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
157163

164+
if label_paths is None:
165+
label_paths = data_paths
166+
elif len(label_paths) != len(data_paths):
167+
raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")
168+
158169
loader = torch_em.default_segmentation_loader(
159170
data_paths, raw_key,
160-
data_paths, label_key, sampler=sampler,
171+
label_paths, label_key, sampler=sampler,
161172
batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
162173
is_seg_dataset=True, label_transform=label_transform, transform=transform,
163174
num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
@@ -177,6 +188,8 @@ def supervised_training(
177188
batch_size: int = 1,
178189
lr: float = 1e-4,
179190
n_iterations: int = int(1e5),
191+
train_label_paths: Optional[Tuple[str]] = None,
192+
val_label_paths: Optional[Tuple[str]] = None,
180193
train_rois: Optional[Tuple[Tuple[slice]]] = None,
181194
val_rois: Optional[Tuple[Tuple[slice]]] = None,
182195
sampler: Optional[callable] = None,
@@ -210,6 +223,10 @@ def supervised_training(
210223
batch_size: The batch size for training.
211224
lr: The initial learning rate.
212225
n_iterations: The number of iterations to train for.
226+
train_label_paths: Optional paths containing the label data for training.
227+
If not given, the labels are expected to be part of `train_paths`.
228+
val_label_paths: Optional paths containing the label data for validation.
229+
If not given, the labels are expected to be part of `val_paths`.
213230
train_rois: Optional region of interests for training.
214231
val_rois: Optional region of interests for validation.
215232
sampler: Optional sampler for selecting blocks for training.
@@ -231,11 +248,11 @@ def supervised_training(
231248
train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
232249
n_samples=n_samples_train, rois=train_rois, sampler=sampler,
233250
ignore_label=ignore_label, label_transform=label_transform,
234-
**loader_kwargs)
251+
label_paths=train_label_paths, **loader_kwargs)
235252
val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
236253
n_samples=n_samples_val, rois=val_rois, sampler=sampler,
237254
ignore_label=ignore_label, label_transform=label_transform,
238-
**loader_kwargs)
255+
label_paths=val_label_paths, **loader_kwargs)
239256

240257
if check:
241258
from torch_em.util.debug import check_loader
@@ -287,3 +304,105 @@ def supervised_training(
287304
metric=metric,
288305
)
289306
trainer.fit(n_iterations)
307+
308+
309+
def _parse_input_folder(folder, pattern, key):
    """Collect the files matching a pattern below a folder and determine the internal data key.

    Args:
        folder: The root folder that is searched (recursively) for files.
        pattern: The glob pattern for selecting files, for example '*.mrc'.
        key: The internal path for the data. If None, the key is derived from the file extension.

    Returns:
        The sorted list of matching file paths.
        The key for the data; this is the passed key or the key derived from the file extension.

    Raises:
        ValueError: If no file matches the pattern, if no key was passed and it can't be derived
            from the file extension, or if the passed key contradicts the file extension.
    """
    # 'recursive=True' is required for "**" to match the folder itself and nested sub-folders.
    # Without it "**" behaves like "*" and only files exactly one level below 'folder' are found.
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    if not files:
        raise ValueError(f"Could not find any files matching the pattern {pattern} in {folder}.")

    # Get all unique file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = sorted({os.path.splitext(ff)[1] for ff in files})

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = extensions[0]
    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}

    if key is None:
        # Derive the key from the extension; raise an error if it can't be derived.
        if ext not in extension_to_key:
            raise ValueError(
                f"You have not passed a key for the data in {folder}, "
                f"but the key could not be derived for {ext} format."
            )
        key = extension_to_key[ext]
    # If the key was passed and doesn't match the extension raise an error.
    elif ext in extension_to_key and key != extension_to_key[ext]:
        raise ValueError(
            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
        )
    return files, key
336+
337+
338+
def _parse_input_files(args):
    """Determine the image and label paths for training and validation from the CLI arguments.

    If no validation folder is given, the training data is split into a training and validation
    fraction according to `args.val_fraction`. Otherwise the validation data is parsed from
    `args.val_folder` and `args.val_label_folder`, using the same file patterns and keys as
    for the training data.

    Args:
        args: The parsed command line arguments.

    Returns:
        The training image paths, training label paths, validation image paths, validation label paths,
        the key for the image data and the key for the label data.

    Raises:
        ValueError: If the number of image and label files doesn't match,
            or if only one of `val_folder` and `val_label_folder` was passed.
    """
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # Fixed random state, so that the train / val split is reproducible across runs.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # The parser returns (files, key); the keys were already determined from the training data.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
        if len(val_image_paths) != len(val_label_paths):
            raise ValueError(
                f"The image and label paths parsed from {args.val_folder} and {args.val_label_folder} don't match. "
                f"The image folder contains {len(val_image_paths)}, "
                f"the label folder contains {len(val_label_paths)}."
            )

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key
360+
361+
362+
# TODO enable initialization with a pre-trained model.
363+
def main():
    """@private

    Command line interface for supervised training of a foreground and boundary segmentation model.
    Parses the arguments, collects the training and validation files and runs `supervised_training`.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning."
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # NOTE(review): the patch shape is passed on as-is (None if not given) - confirm that
    # 'supervised_training' can handle a missing patch shape or make this argument required.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    # The short option must use a single dash: with "--i" argparse would store the value as 'args.i'
    # (the destination is derived from the first long option) and 'args.train_folder' would not exist.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    # 'args.check' is forwarded to the training function below, so the flag has to be defined here.
    parser.add_argument("--check", action="store_true",
                        help="Pass the 'check' flag to the training function in order to inspect the data loaders.")
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check,
    )

0 commit comments

Comments
 (0)