Commit 292e450

fix domain_adaptation.py

2 parents: 712714f + 702138f

1 file changed: +146 additions, -3 deletions

synapse_net/training/domain_adaptation.py

@@ -1,12 +1,22 @@
 import os
+import tempfile
+from glob import glob
+from pathlib import Path
 from typing import Optional, Tuple
 
+import mrcfile
 import torch
 import torch_em
 import torch_em.self_training as self_training
+from elf.io import open_file
+from sklearn.model_selection import train_test_split
 
 from .semisupervised_training import get_unsupervised_loader
-from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
+from .supervised_training import (
+    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
+)
+from ..inference.inference import get_model_path, compute_scale_from_voxel_size
+from ..inference.util import _Scaler
 
 class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
     """Compute pseudo labels based on model predictions, typically from a teacher model.
@@ -109,7 +119,7 @@ def mean_teacher_adaptation(
     pseudo_label_sampler: Optional[callable] = None,
     device: int = 0,
 ) -> None:
-    """Run domain adapation to transfer a network trained on a source domain for a supervised
+    """Run domain adaptation to transfer a network trained on a source domain for a supervised
     segmentation task to perform this task on a different target domain.
 
     We support different domain adaptation settings:
@@ -177,7 +187,7 @@ def mean_teacher_adaptation(
         if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
-           model = torch.load(source_checkpoint)
+           model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False
 
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
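
The `weights_only` change tracks a PyTorch behavior change: since PyTorch 2.6, `torch.load` defaults to `weights_only=True` and refuses to unpickle full model objects, which is exactly what this branch loads. A minimal sketch of the opt-out (the checkpoint path is illustrative, and this should only be used for trusted files):

    import torch

    # torch.save(model, "checkpoint.pt") stores the pickled model object, not
    # just a state dict. With PyTorch >= 2.6 the default weights_only=True
    # rejects such pickles at load time, hence the explicit opt-out.
    model = torch.load("checkpoint.pt", weights_only=False)  # trusted checkpoints only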
@@ -252,3 +262,136 @@ def mean_teacher_adaptation(
         sampler=pseudo_label_sampler,
     )
     trainer.fit(n_iterations)
+
+# TODO patch shapes for other models
+PATCH_SHAPES = {
+    "vesicles_3d": [48, 256, 256],
+}
+"""@private
+"""
+
+def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
+    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
+    if len(files) == 0:
+        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")
+
+    # Heuristic: if we have less than 4 files then we crop a part of the volumes for validation
+    # and resave the volumes.
+    resave_val_crops = len(files) < 4
+
+    # We only resave the data if we resave val crops or resize the training data.
+    resave_data = resave_val_crops or resize_training_data
+    if not resave_data:
+        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
+        return train_paths, val_paths
+
+    train_paths, val_paths = [], []
+    for file_path in files:
+        file_name = os.path.basename(file_path)
+        data = open_file(file_path, mode="r")["data"][:]
+
+        if resize_training_data:
+            with mrcfile.open(file_path) as f:
+                voxel_size = f.voxel_size
+            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
+            scale = compute_scale_from_voxel_size(voxel_size, model_name)
+            scaler = _Scaler(scale, verbose=False)
+            data = scaler.scale_input(data)
+
+        if resave_val_crops:
+            n_slices = data.shape[0]
+            val_slice = int((1.0 - val_fraction) * n_slices)
+            train_data, val_data = data[:val_slice], data[val_slice:]
+
+            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
+            with open_file(train_path, mode="w") as f:
+                f.create_dataset("data", data=train_data, compression="lzf")
+            train_paths.append(train_path)
+
+            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
+            with open_file(val_path, mode="w") as f:
+                f.create_dataset("data", data=val_data, compression="lzf")
+            val_paths.append(val_path)
+
+        else:
+            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
+            with open_file(output_path, mode="w") as f:
+                f.create_dataset("data", data=data, compression="lzf")
+            train_paths.append(output_path)
+
+    if not resave_val_crops:
+        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)
+
+    return train_paths, val_paths
+
+
+def _parse_patch_shape(patch_shape, model_name):
+    if patch_shape is None:
+        patch_shape = PATCH_SHAPES[model_name]
+    return patch_shape
+
+def main():
+    """@private
+    """
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
+                    "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
+                    "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
+                    "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). "  # noqa
+                    "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
+                    "Check out the information below for details on the arguments of this function.",
+        formatter_class=argparse.RawTextHelpFormatter
+    )
+    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
+    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
+    parser.add_argument("--file_pattern", default="*",
+                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
+    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
+    parser.add_argument(
+        "--source_model",
+        default="vesicles_3d",
+        help="The source model used for weight initialization of teacher and student model. "
+             "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
+    )
+    parser.add_argument(
+        "--resize_training_data", action="store_true",
+        help="Whether to resize the training data to fit the voxel size of the source model's training data."
+    )
+    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
+    parser.add_argument(
+        "--patch_shape", nargs=3, type=int,
+        help="The patch shape for training. By default the patch shape the source model was trained with is used."
+    )
+
+    # More optional arguments:
+    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
+    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
+    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
+    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
+    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
+
+    args = parser.parse_args()
+
+    source_checkpoint = get_model_path(args.source_model)
+    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
+            args.input_folder, args.file_pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction,
+        )
+        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)
+
+        mean_teacher_adaptation(
+            name=args.name,
+            unsupervised_train_paths=unsupervised_train_paths,
+            unsupervised_val_paths=unsupervised_val_paths,
+            patch_shape=patch_shape,
+            source_checkpoint=source_checkpoint,
+            raw_key=raw_key,
+            n_iterations=args.n_iterations,
+            batch_size=args.batch_size,
+            n_samples_train=args.n_samples_train,
+            n_samples_val=args.n_samples_val,
+            check=args.check,
+        )
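
For reference, a minimal sketch of calling the adapted training directly from Python instead of via the `synapse_net.run_domain_adaptation` entry point; all parameter names come from the diff above, while the paths and the model name are placeholders:

    from synapse_net.training.domain_adaptation import mean_teacher_adaptation

    mean_teacher_adaptation(
        name="adapted_model",                                 # placeholder model name
        unsupervised_train_paths=["/path/to/tomo_train.h5"],  # placeholder paths
        unsupervised_val_paths=["/path/to/tomo_val.h5"],
        patch_shape=[48, 256, 256],                           # the vesicles_3d entry from PATCH_SHAPES
        source_checkpoint="/path/to/vesicles_3d_checkpoint",  # placeholder checkpoint
        raw_key="data",                                       # dataset key used by _get_paths when resaving
        n_iterations=10_000,
        batch_size=1,
    )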
