|
12 | 12 | from sklearn.model_selection import train_test_split |
13 | 13 |
|
14 | 14 | from .semisupervised_training import get_unsupervised_loader |
15 | | -from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim |
| 15 | +from .supervised_training import ( |
| 16 | + get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files |
| 17 | +) |
16 | 18 | from ..inference.inference import get_model_path, compute_scale_from_voxel_size |
17 | 19 | from ..inference.util import _Scaler |
18 | 20 |
|
@@ -166,13 +168,11 @@ def mean_teacher_adaptation( |
166 | 168 | """ |
167 | 169 |
|
168 | 170 |
|
169 | | -def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir): |
| 171 | +def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction): |
170 | 172 | files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True)) |
171 | 173 | if len(files) == 0: |
172 | 174 | raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}") |
173 | 175 |
|
174 | | - val_fraction = 0.15 |
175 | | - |
176 | 176 | # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation. |
177 | 177 | # And resave the volumes. |
178 | 178 | resave_val_crops = len(files) < 4 |
@@ -235,30 +235,61 @@ def main(): |
235 | 235 | import argparse |
236 | 236 |
|
237 | 237 | parser = argparse.ArgumentParser( |
238 | | - description="" |
| 238 | + description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n" |
| 239 | + "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n" |
| 240 | + "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n" # noqa |
| 241 | +        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). " # noqa
| 242 | + "You can then use this model for segmentation with the SynapseNet GUI or CLI. " |
| 243 | + "Check out the information below for details on the arguments of this function." |
| 244 | + ) |
| 245 | + parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ") |
| 246 | + parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.") |
| 247 | + parser.add_argument("--file_pattern", default="*", |
| 248 | + help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.") |
| 249 | + parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.") # noqa |
| 250 | + parser.add_argument( |
| 251 | + "--source_model", |
| 252 | + default="vesicles_3d", |
| 253 | + help="The source model used for weight initialization of teacher and student model. " |
| 254 | + "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used." |
| 255 | + ) |
| 256 | + parser.add_argument( |
| 257 | + "--resize_training_data", action="store_true", |
| 258 | +        help="Whether to resize the training data to fit the voxel size of the source model's training data."
239 | 259 | ) |
240 | | - parser.add_argument("--name", "-n", required=True) |
241 | | - parser.add_argument("--input", "-i", required=True) |
242 | | - parser.add_argument("--pattern", "-p", default="*.mrc") |
243 | | - parser.add_argument("--source_model", default="vesicles_3d") |
244 | | - parser.add_argument("--resize_training_data", action="store_true") |
245 | | - parser.add_argument("--n_iterations", type=int, default=int(1e4)) |
246 | | - parser.add_argument("--patch_shape", nargs="+", type=int) |
| 260 | + parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.") |
| 261 | + parser.add_argument( |
| 262 | + "--patch_shape", nargs=3, type=int, |
| 263 | + help="The patch shape for training. By default the patch shape the source model was trained with is used." |
| 264 | + ) |
| 265 | + |
| 266 | +    # More optional arguments:
| 267 | + parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.") |
| 268 | + parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.") # noqa |
| 269 | + parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa |
| 270 | + parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa |
| 271 | + parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa |
| 272 | + |
247 | 273 | args = parser.parse_args() |
248 | 274 |
|
249 | 275 | source_checkpoint = get_model_path(args.source_model) |
250 | 276 | patch_shape = _parse_patch_shape(args.patch_shape, args.source_model) |
251 | 277 | with tempfile.TemporaryDirectory() as tmp_dir: |
252 | 278 | unsupervised_train_paths, unsupervised_val_paths = _get_paths( |
253 | | - args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir |
| 279 | +            args.input_folder, args.file_pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction,
254 | 280 | ) |
| 281 | + unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key) |
255 | 282 |
|
256 | 283 | mean_teacher_adaptation( |
257 | 284 | name=args.name, |
258 | 285 | unsupervised_train_paths=unsupervised_train_paths, |
259 | 286 | unsupervised_val_paths=unsupervised_val_paths, |
260 | 287 | patch_shape=patch_shape, |
261 | 288 | source_checkpoint=source_checkpoint, |
262 | | - raw_key="data", |
| 289 | + raw_key=raw_key, |
263 | 290 | n_iterations=args.n_iterations, |
| 291 | + batch_size=args.batch_size, |
| 292 | + n_samples_train=args.n_samples_train, |
| 293 | + n_samples_val=args.n_samples_val, |
| 294 | + check=args.check, |
264 | 295 | ) |
0 commit comments