
Commit 6e865c3

implement lamella mask
1 parent 4c17c0d commit 6e865c3

2 files changed: +26 −168 lines changed
Lines changed: 17 additions & 160 deletions
@@ -1,22 +1,12 @@
 import os
-import tempfile
-from glob import glob
-from pathlib import Path
 from typing import Optional, Tuple

-import mrcfile
 import torch
 import torch_em
 import torch_em.self_training as self_training
-from elf.io import open_file
-from sklearn.model_selection import train_test_split

 from .semisupervised_training import get_unsupervised_loader
-from .supervised_training import (
-    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
-)
-from ..inference.inference import get_model_path, compute_scale_from_voxel_size
-from ..inference.util import _Scaler
+from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim

 def mean_teacher_adaptation(
     name: str,
@@ -38,10 +28,11 @@ def mean_teacher_adaptation(
     n_samples_val: Optional[int] = None,
     train_mask_paths: Optional[Tuple[str]] = None,
     val_mask_paths: Optional[Tuple[str]] = None,
-    sampler: Optional[callable] = None,
+    patch_sampler: Optional[callable] = None,
+    pseudo_label_sampler: Optional[callable] = None,
     device: int = 0,
 ) -> None:
-    """Run domain adapation to transfer a network trained on a source domain for a supervised
+    """Run domain adaptation to transfer a network trained on a source domain for a supervised
     segmentation task to perform this task on a different target domain.

     We support different domain adaptation settings:
@@ -84,10 +75,11 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
-        train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
-        val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
-        sampler: Accept or reject patches based on a condition.
-        device: GPU ID for training.
+        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
+        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
+        patch_sampler: Accept or reject patches based on a condition.
+        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
+        device: GPU ID for training.
     """
     assert (supervised_train_paths is None) == (supervised_val_paths is None)
     is_2d, _ = _determine_ndim(patch_shape)
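A patch sampler compatible with this interface is any callable that takes the raw patch and its mask patch and returns a bool; that is how the wrapper in semisupervised_training.py (second file below) invokes it. A minimal sketch, assuming patches are accepted once 95% of their voxels lie inside the sample mask (the name and threshold are illustrative, not part of this commit):

import numpy as np

def lamella_patch_sampler(raw: np.ndarray, mask: np.ndarray) -> bool:
    # Accept the patch only if at least 95% of its voxels fall inside the sample mask.
    return (mask > 0).mean() >= 0.95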
@@ -103,11 +95,11 @@
         model = get_3d_model(out_channels=2)
         reinit_teacher = True
     else:
-        print("Mean teacehr training initialized from source model:", source_checkpoint)
+        print("Mean teacher training initialized from source model:", source_checkpoint)
         if os.path.isdir(source_checkpoint):
             model = torch_em.util.load_model(source_checkpoint)
         else:
-            model = torch.load(source_checkpoint, weights_only=False)
+            model = torch.load(source_checkpoint)
         reinit_teacher = False

     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
@@ -117,23 +109,24 @@
     pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
     loss = self_training.DefaultSelfTrainingLoss()
     loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
-
+
     unsupervised_train_loader = get_unsupervised_loader(
         data_paths=unsupervised_train_paths,
         raw_key=raw_key,
         patch_shape=patch_shape,
         batch_size=batch_size,
         n_samples=n_samples_train,
-        boundary_mask_paths=train_mask_paths,
-        sampler=sampler
+        sample_mask_paths=train_mask_paths,
+        sampler=patch_sampler
     )
     unsupervised_val_loader = get_unsupervised_loader(
         data_paths=unsupervised_val_paths,
         raw_key=raw_key,
         patch_shape=patch_shape,
         batch_size=batch_size,
         n_samples=n_samples_val,
-        boundary_mask_paths=val_mask_paths, sampler=sampler
+        sample_mask_paths=val_mask_paths,
+        sampler=patch_sampler
     )

     if supervised_train_paths is not None:
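The pseudo_label_sampler handed to the trainer in the next hunk operates on the teacher's predictions rather than on raw patches. A rough sketch of the underlying idea, assuming sigmoid outputs (the actual filtering lives in torch_em.self_training and may differ in detail):

import torch

def confidence_mask(teacher_logits: torch.Tensor, threshold: float = 0.9) -> torch.Tensor:
    # Keep only voxels where the teacher is confidently foreground or confidently background.
    probs = torch.sigmoid(teacher_logits)
    confident = (probs >= threshold) | (probs <= 1.0 - threshold)
    return confident.float()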
@@ -172,142 +165,6 @@
         device=device,
         reinit_teacher=reinit_teacher,
         save_root=save_root,
-        sampler=None,  # TODO currently set to none cause I didn't want to pass the same sampler used by get_unsupervised_loader
+        sampler=pseudo_label_sampler,
     )
     trainer.fit(n_iterations)
-
-
-# TODO patch shapes for other models
-PATCH_SHAPES = {
-    "vesicles_3d": [48, 256, 256],
-}
-"""@private
-"""
-
-
-def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
-    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
-    if len(files) == 0:
-        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")
-
-    # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
-    # And resave the volumes.
-    resave_val_crops = len(files) < 4
-
-    # We only resave the data if we resave val crops or resize the training data
-    resave_data = resave_val_crops or resize_training_data
-    if not resave_data:
-        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
-        return train_paths, val_paths
-
-    train_paths, val_paths = [], []
-    for file_path in files:
-        file_name = os.path.basename(file_path)
-        data = open_file(file_path, mode="r")["data"][:]
-
-        if resize_training_data:
-            with mrcfile.open(file_path) as f:
-                voxel_size = f.voxel_size
-            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
-            scale = compute_scale_from_voxel_size(voxel_size, model_name)
-            scaler = _Scaler(scale, verbose=False)
-            data = scaler.sale_input(data)
-
-        if resave_val_crops:
-            n_slices = data.shape[0]
-            val_slice = int((1.0 - val_fraction) * n_slices)
-            train_data, val_data = data[:val_slice], data[val_slice:]
-
-            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
-            with open_file(train_path, mode="w") as f:
-                f.create_dataset("data", data=train_data, compression="lzf")
-            train_paths.append(train_path)
-
-            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
-            with open_file(val_path, mode="w") as f:
-                f.create_dataset("data", data=val_data, compression="lzf")
-            val_paths.append(val_path)
-
-        else:
-            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
-            with open_file(output_path, mode="w") as f:
-                f.create_dataset("data", data=data, compression="lzf")
-            train_paths.append(output_path)
-
-    if not resave_val_crops:
-        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)
-
-    return train_paths, val_paths
-
-
-def _parse_patch_shape(patch_shape, model_name):
-    if patch_shape is None:
-        patch_shape = PATCH_SHAPES[model_name]
-    return patch_shape
-
-
-def main():
-    """@private
-    """
-    import argparse
-
-    parser = argparse.ArgumentParser(
-        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
-                    "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
-                    "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
-                    "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)."  # noqa
-                    "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
-                    "Check out the information below for details on the arguments of this function.",
-        formatter_class=argparse.RawTextHelpFormatter
-    )
-    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
-    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
-    parser.add_argument("--file_pattern", default="*",
-                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
-    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
-    parser.add_argument(
-        "--source_model",
-        default="vesicles_3d",
-        help="The source model used for weight initialization of teacher and student model. "
-             "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
-    )
-    parser.add_argument(
-        "--resize_training_data", action="store_true",
-        help="Whether to resize the training data to fit the voxel size of the source model's trainign data."
-    )
-    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
-    parser.add_argument(
-        "--patch_shape", nargs=3, type=int,
-        help="The patch shape for training. By default the patch shape the source model was trained with is used."
-    )
-
-    # More optional argument:
-    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
-    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
-    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
-    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
-    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
-
-    args = parser.parse_args()
-
-    source_checkpoint = get_model_path(args.source_model)
-    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
-            args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction,
-        )
-        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)
-
-        mean_teacher_adaptation(
-            name=args.name,
-            unsupervised_train_paths=unsupervised_train_paths,
-            unsupervised_val_paths=unsupervised_val_paths,
-            patch_shape=patch_shape,
-            source_checkpoint=source_checkpoint,
-            raw_key=raw_key,
-            n_iterations=args.n_iterations,
-            batch_size=args.batch_size,
-            n_samples_train=args.n_samples_train,
-            n_samples_val=args.n_samples_val,
-            check=args.check,
-        )
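Taken together, a call against the updated interface might look like the following sketch. The paths, mask files, and lamella_patch_sampler are hypothetical, and the remaining arguments (raw_key among them) are left at their defaults:

mean_teacher_adaptation(
    name="adapted_vesicle_model",
    unsupervised_train_paths=("tomo_a.mrc",),
    unsupervised_val_paths=("tomo_b.mrc",),
    patch_shape=(48, 256, 256),
    source_checkpoint="checkpoints/vesicles_3d",
    train_mask_paths=("tomo_a_lamella_mask.mrc",),
    val_mask_paths=("tomo_b_lamella_mask.mrc",),
    patch_sampler=lamella_patch_sampler,  # accepts or rejects patches, see the sketch above
    pseudo_label_sampler=None,  # optionally filter low-confidence pseudo labels
)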

synapse_net/training/semisupervised_training.py

Lines changed: 9 additions & 8 deletions
@@ -51,16 +51,16 @@ def __init__(self, sampler):
     def __call__(self, x):
         raw, mask = x[0], x[1]
         return self.sampler(raw, mask)
-
+

 def get_unsupervised_loader(
     data_paths: Tuple[str],
     raw_key: str,
     patch_shape: Tuple[int, int, int],
     batch_size: int,
     n_samples: Optional[int],
-    boundary_mask_paths: Optional[Tuple[str]] = None,
+    sample_mask_paths: Optional[Tuple[str]] = None,
     sampler: Optional[callable] = None,
-    exclude_top_and_bottom: bool = False,  # TODO this seems unneccesary if we have a boundary mask - remove?
+    exclude_top_and_bottom: bool = False,
 ) -> torch.utils.data.DataLoader:
     """Get a dataloader for unsupervised segmentation training.
@@ -75,7 +75,7 @@ def get_unsupervised_loader(
             based on the patch_shape and size of the volumes used for training.
         exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
             avoid artifacts at the border of tomograms.
-        boundary_mask_paths: The filepaths to the corresponding boundary masks for each tomogram.
+        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
         sampler: Accept or reject patches based on a condition.

     Returns:
@@ -88,12 +88,12 @@ def get_unsupervised_loader(
     else:
         roi = None
     # stack tomograms and masks and write to temp files to use as input to RawDataset()
-    if boundary_mask_paths is not None:
-        assert len(data_paths) == len(boundary_mask_paths), \
-            f"Expected equal number of data_paths and and boundary_masks_paths, got {len(data_paths)} data paths and {len(boundary_mask_paths)} mask paths."
+    if sample_mask_paths is not None:
+        assert len(data_paths) == len(sample_mask_paths), \
+            f"Expected equal number of data_paths and sample_mask_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."

         stacked_paths = []
-        for i, (data_path, mask_path) in enumerate(zip(data_paths, boundary_mask_paths)):
+        for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
             raw = read_mrc(data_path)[0]
             mask = read_mrc(mask_path)[0]
             stacked = np.stack([raw, mask], axis=0)
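This stacking is what the sampler wrapper at the top of the file relies on: raw volume and mask travel through the dataset together as channels 0 and 1 of a single array, and the wrapper splits them again before calling the user's sampler. A self-contained sketch of that round trip, with synthetic arrays standing in for read_mrc output:

import numpy as np

raw = np.random.rand(64, 512, 512).astype("float32")  # stand-in tomogram
mask = np.zeros_like(raw)  # stand-in lamella/sample mask
mask[:, 128:384, 128:384] = 1.0

stacked = np.stack([raw, mask], axis=0)  # shape (2, 64, 512, 512)

# A patch cut from the stacked array, unpacked the way the wrapper's __call__ does:
patch = stacked[:, :48, :256, :256]
raw_patch, mask_patch = patch[0], patch[1]
accept = (mask_patch > 0).mean() >= 0.95  # e.g. the mask-coverage condition from the earlier sketch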
@@ -136,6 +136,7 @@ def get_unsupervised_loader(
                          num_workers=num_workers, shuffle=True)
     return loader

+
 # TODO: use different paths for supervised and unsupervised training
 # (We are currently not using this functionality directly, so this is not a high priority)
 def semisupervised_training(
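With the renamed keyword, building the loader directly might look like this sketch; the paths are hypothetical, and raw_key="data" is an assumption about the storage key rather than something this diff specifies:

loader = get_unsupervised_loader(
    data_paths=("tomo_a.mrc",),
    raw_key="data",  # assumed internal key, depends on the data format
    patch_shape=(48, 256, 256),
    batch_size=1,
    n_samples=None,  # derived from the data size by default
    sample_mask_paths=("tomo_a_lamella_mask.mrc",),
    sampler=lamella_patch_sampler,  # hypothetical, see the earlier sketch
)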

0 commit comments
