|
1 | 1 | import os |
| 2 | +import tempfile |
| 3 | +from glob import glob |
| 4 | +from pathlib import Path |
2 | 5 | from typing import Optional, Tuple |
3 | 6 |
|
| 7 | +import mrcfile |
4 | 8 | import torch |
5 | 9 | import torch_em |
6 | 10 | import torch_em.self_training as self_training |
| 11 | +from elf.io import open_file |
| 12 | +from sklearn.model_selection import train_test_split |
7 | 13 |
|
8 | 14 | from .semisupervised_training import get_unsupervised_loader |
9 | 15 | from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim |
| 16 | +from ..inference.inference import get_model_path, compute_scale_from_voxel_size |
| 17 | +from ..inference.util import _Scaler |
10 | 18 |
|
11 | 19 |
|
12 | 20 | def mean_teacher_adaptation( |
@@ -91,7 +99,7 @@ def mean_teacher_adaptation( |
91 | 99 | if os.path.isdir(source_checkpoint): |
92 | 100 | model = torch_em.util.load_model(source_checkpoint) |
93 | 101 | else: |
94 | | - model = torch.load(source_checkpoint) |
| 102 | + model = torch.load(source_checkpoint, weights_only=False) |
95 | 103 | reinit_teacher = False |
96 | 104 |
|
97 | 105 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) |
@@ -148,3 +156,109 @@ def mean_teacher_adaptation( |
148 | 156 | sampler=sampler, |
149 | 157 | ) |
150 | 158 | trainer.fit(n_iterations) |
| 159 | + |
| 160 | + |
| 161 | +# TODO patch shapes for other models |
| 162 | +PATCH_SHAPES = { |
| 163 | + "vesicles_3d": [48, 256, 256], |
| 164 | +} |
| 165 | +"""@private |
| 166 | +""" |
| 167 | + |
| 168 | + |
| 169 | +def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir): |
| 170 | + files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True)) |
| 171 | + if len(files) == 0: |
| 172 | + raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}") |
| 173 | + |
| 174 | + val_fraction = 0.15 |
| 175 | + |
| 176 | + # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation. |
| 177 | + # And resave the volumes. |
| 178 | + resave_val_crops = len(files) < 4 |
| 179 | + |
| 180 | + # We only resave the data if we resave val crops or resize the training data |
| 181 | + resave_data = resave_val_crops or resize_training_data |
| 182 | + if not resave_data: |
| 183 | + train_paths, val_paths = train_test_split(files, test_size=val_fraction) |
| 184 | + return train_paths, val_paths |
| 185 | + |
| 186 | + train_paths, val_paths = [], [] |
| 187 | + for file_path in files: |
| 188 | + file_name = os.path.basename(file_path) |
| 189 | + data = open_file(file_path, mode="r")["data"][:] |
| 190 | + |
| 191 | + if resize_training_data: |
| 192 | + with mrcfile.open(file_path) as f: |
| 193 | + voxel_size = f.voxel_size |
| 194 | + voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())} |
| 195 | + scale = compute_scale_from_voxel_size(voxel_size, model_name) |
| 196 | + scaler = _Scaler(scale, verbose=False) |
| 197 | + data = scaler.sale_input(data) |
| 198 | + |
| 199 | + if resave_val_crops: |
| 200 | + n_slices = data.shape[0] |
| 201 | + val_slice = int((1.0 - val_fraction) * n_slices) |
| 202 | + train_data, val_data = data[:val_slice], data[val_slice:] |
| 203 | + |
| 204 | + train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5") |
| 205 | + with open_file(train_path, mode="w") as f: |
| 206 | + f.create_dataset("data", data=train_data, compression="lzf") |
| 207 | + train_paths.append(train_path) |
| 208 | + |
| 209 | + val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5") |
| 210 | + with open_file(val_path, mode="w") as f: |
| 211 | + f.create_dataset("data", data=val_data, compression="lzf") |
| 212 | + val_paths.append(val_path) |
| 213 | + |
| 214 | + else: |
| 215 | + output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")) |
| 216 | + with open_file(output_path, mode="w") as f: |
| 217 | + f.create_dataset("data", data=data, compression="lzf") |
| 218 | + train_paths.append(output_path) |
| 219 | + |
| 220 | + if not resave_val_crops: |
| 221 | + train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction) |
| 222 | + |
| 223 | + return train_paths, val_paths |
| 224 | + |
| 225 | + |
| 226 | +def _parse_patch_shape(patch_shape, model_name): |
| 227 | + if patch_shape is None: |
| 228 | + patch_shape = PATCH_SHAPES[model_name] |
| 229 | + return patch_shape |
| 230 | + |
| 231 | + |
| 232 | +def main(): |
| 233 | + """@private |
| 234 | + """ |
| 235 | + import argparse |
| 236 | + |
| 237 | + parser = argparse.ArgumentParser( |
| 238 | + description="" |
| 239 | + ) |
| 240 | + parser.add_argument("--name", "-n", required=True) |
| 241 | + parser.add_argument("--input", "-i", required=True) |
| 242 | + parser.add_argument("--pattern", "-p", default="*.mrc") |
| 243 | + parser.add_argument("--source_model", default="vesicles_3d") |
| 244 | + parser.add_argument("--resize_training_data", action="store_true") |
| 245 | + parser.add_argument("--n_iterations", type=int, default=int(1e4)) |
| 246 | + parser.add_argument("--patch_shape", nargs="+", type=int) |
| 247 | + args = parser.parse_args() |
| 248 | + |
| 249 | + source_checkpoint = get_model_path(args.source_model) |
| 250 | + patch_shape = _parse_patch_shape(args.patch_shape, args.source_model) |
| 251 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 252 | + unsupervised_train_paths, unsupervised_val_paths = _get_paths( |
| 253 | + args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir |
| 254 | + ) |
| 255 | + |
| 256 | + mean_teacher_adaptation( |
| 257 | + name=args.name, |
| 258 | + unsupervised_train_paths=unsupervised_train_paths, |
| 259 | + unsupervised_val_paths=unsupervised_val_paths, |
| 260 | + patch_shape=patch_shape, |
| 261 | + source_checkpoint=source_checkpoint, |
| 262 | + raw_key="data", |
| 263 | + n_iterations=args.n_iterations, |
| 264 | + ) |
0 commit comments