Skip to content

Commit be0917a

Browse files
Fix issues in training CLI and add domain adaptation CLI
1 parent 70628f6 commit be0917a

File tree

5 files changed

+125
-10
lines changed

5 files changed

+125
-10
lines changed

scripts/cooper/revision/az_prediction.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def run_prediction(model, name, split_folder, version, split_names, in_path):
2424

2525
for fname in tqdm(file_names):
2626
if in_path:
27-
input_path=os.path.join(in_path, name, fname)
27+
input_path = os.path.join(in_path, name, fname)
2828
else:
2929
input_path = os.path.join(INPUT_ROOT, name, fname)
3030
print(f"segmenting {input_path}")
@@ -50,15 +50,14 @@ def run_prediction(model, name, split_folder, version, split_names, in_path):
5050
print(f"{output_key_seg} already saved")
5151
else:
5252
f.create_dataset(output_key_seg, data=seg, compression="lzf")
53-
5453

5554

5655
def get_model(version):
5756
assert version in (3, 4, 5, 6, 7)
5857
split_folder = get_split_folder(version)
5958
if version == 3:
6059
model_path = os.path.join(split_folder, "checkpoints", "3D-AZ-model-TEM_STEM_ChemFix_wichmann-v3")
61-
elif version ==6:
60+
elif version == 6:
6261
model_path = "/mnt/ceph-hdd/cold/nim00007/models/AZ/v6/"
6362
elif version == 7:
6463
model_path = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ/checkpoints/v7/"
@@ -79,15 +78,15 @@ def main():
7978
args = parser.parse_args()
8079

8180
if args.model_path:
82-
model = load_model(model_path)
81+
model = load_model(args.model_path)
8382
else:
8483
model = get_model(args.version)
8584

8685
split_folder = get_split_folder(args.version)
8786

8887
for name in args.names:
8988
run_prediction(model, name, split_folder, args.version, args.splits, args.input)
90-
89+
9190
print("Finished segmenting!")
9291

9392

scripts/cooper/revision/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_split_folder(version):
6565
if version == 3:
6666
split_folder = "splits"
6767
elif version == 6:
68-
split_folder= "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/splits"
68+
split_folder = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/splits"
6969
else:
7070
split_folder = "models_az_thin"
7171
return split_folder

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"synapse_net.export_to_imod_points = synapse_net.tools.cli:imod_point_cli",
1818
"synapse_net.export_to_imod_objects = synapse_net.tools.cli:imod_object_cli",
1919
"synapse_net.run_supervised_training = synapse_net.training.supervised_training:main",
20+
"synapse_net.run_domain_adaptation = synapse_net.training.domain_adaptation:main",
2021
],
2122
"napari.manifest": [
2223
"synapse_net = synapse_net:napari.yaml",

synapse_net/training/domain_adaptation.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
import os
2+
import tempfile
3+
from glob import glob
4+
from pathlib import Path
25
from typing import Optional, Tuple
36

7+
import mrcfile
48
import torch
59
import torch_em
610
import torch_em.self_training as self_training
11+
from elf.io import open_file
12+
from sklearn.model_selection import train_test_split
713

814
from .semisupervised_training import get_unsupervised_loader
915
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
16+
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
17+
from ..inference.util import _Scaler
1018

1119

1220
def mean_teacher_adaptation(
@@ -91,7 +99,7 @@ def mean_teacher_adaptation(
9199
if os.path.isdir(source_checkpoint):
92100
model = torch_em.util.load_model(source_checkpoint)
93101
else:
94-
model = torch.load(source_checkpoint)
102+
model = torch.load(source_checkpoint, weights_only=False)
95103
reinit_teacher = False
96104

97105
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
@@ -148,3 +156,109 @@ def mean_teacher_adaptation(
148156
sampler=sampler,
149157
)
150158
trainer.fit(n_iterations)
159+
160+
161+
# Default patch shapes, keyed by the name of the source model.
# Looked up by `_parse_patch_shape` when no patch shape is passed explicitly.
# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""
167+
168+
169+
def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir):
170+
files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
171+
if len(files) == 0:
172+
raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")
173+
174+
val_fraction = 0.15
175+
176+
# Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
177+
# And resave the volumes.
178+
resave_val_crops = len(files) < 4
179+
180+
# We only resave the data if we resave val crops or resize the training data
181+
resave_data = resave_val_crops or resize_training_data
182+
if not resave_data:
183+
train_paths, val_paths = train_test_split(files, test_size=val_fraction)
184+
return train_paths, val_paths
185+
186+
train_paths, val_paths = [], []
187+
for file_path in files:
188+
file_name = os.path.basename(file_path)
189+
data = open_file(file_path, mode="r")["data"][:]
190+
191+
if resize_training_data:
192+
with mrcfile.open(file_path) as f:
193+
voxel_size = f.voxel_size
194+
voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
195+
scale = compute_scale_from_voxel_size(voxel_size, model_name)
196+
scaler = _Scaler(scale, verbose=False)
197+
data = scaler.sale_input(data)
198+
199+
if resave_val_crops:
200+
n_slices = data.shape[0]
201+
val_slice = int((1.0 - val_fraction) * n_slices)
202+
train_data, val_data = data[:val_slice], data[val_slice:]
203+
204+
train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
205+
with open_file(train_path, mode="w") as f:
206+
f.create_dataset("data", data=train_data, compression="lzf")
207+
train_paths.append(train_path)
208+
209+
val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
210+
with open_file(val_path, mode="w") as f:
211+
f.create_dataset("data", data=val_data, compression="lzf")
212+
val_paths.append(val_path)
213+
214+
else:
215+
output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
216+
with open_file(output_path, mode="w") as f:
217+
f.create_dataset("data", data=data, compression="lzf")
218+
train_paths.append(output_path)
219+
220+
if not resave_val_crops:
221+
train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)
222+
223+
return train_paths, val_paths
224+
225+
226+
def _parse_patch_shape(patch_shape, model_name):
227+
if patch_shape is None:
228+
patch_shape = PATCH_SHAPES[model_name]
229+
return patch_shape
230+
231+
232+
def main():
    """@private

    Command line entry point for domain adaptation: parses the arguments, collects the
    training and validation data and runs `mean_teacher_adaptation`.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a pretrained model to new data with mean teacher domain adaptation."
    )
    parser.add_argument(
        "--name", "-n", required=True,
        help="The name of the domain adaptation training run."
    )
    parser.add_argument(
        "--input", "-i", required=True,
        help="The folder with the training data. It is searched recursively for files matching the pattern."
    )
    parser.add_argument(
        "--pattern", "-p", default="*.mrc",
        help="The pattern for selecting the input files. By default all mrc files are selected."
    )
    parser.add_argument(
        "--source_model", default="vesicles_3d",
        help="The name of the pretrained model that is adapted."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Rescale the training data according to its voxel size and the source model before training."
    )
    parser.add_argument(
        "--n_iterations", type=int, default=int(1e4),
        help="The number of training iterations."
    )
    parser.add_argument(
        "--patch_shape", nargs="+", type=int,
        help="The patch shape for training. If not given, the default patch shape of the source model is used."
    )
    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)

    # Re-saved training data is written to a temporary folder, so the training
    # has to run while this folder still exists.
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir
        )

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key="data",
            n_iterations=args.n_iterations,
        )

synapse_net/training/supervised_training.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,9 @@ def supervised_training(
307307

308308

309309
def _parse_input_folder(folder, pattern, key):
310-
files = sorted(glob(os.path.join(folder, "**", pattern)))
310+
files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
311311
# Get all file extensions (general wild-cards may pick up files with multiple extensions).
312-
extensions = [os.path.splitext(ff)[1] for ff in files]
312+
extensions = list(set([os.path.splitext(ff)[1] for ff in files]))
313313

314314
# If we have more than 1 file extension we just use the key that was passed,
315315
# as it is unclear how to derive a consistent key.
@@ -372,7 +372,7 @@ def main():
372372
parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.")
373373

374374
# Folders with training data, containing raw/image data and labels.
375-
parser.add_argument("--i", "--train_folder", required=True, help="The input folder with the training image data.")
375+
parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
376376
parser.add_argument("--image_file_pattern", default="*",
377377
help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
378378
parser.add_argument("--raw_key",
@@ -394,6 +394,7 @@ def main():
394394
parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.") # noqa
395395
parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa
396396
parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa
397+
parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa
397398
args = parser.parse_args()
398399

399400
train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\

0 commit comments

Comments
 (0)