Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions synapse_net/training/domain_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.util import _Scaler


def mean_teacher_adaptation(
name: str,
unsupervised_train_paths: Tuple[str],
Expand All @@ -37,7 +36,10 @@ def mean_teacher_adaptation(
n_iterations: int = int(1e4),
n_samples_train: Optional[int] = None,
n_samples_val: Optional[int] = None,
train_mask_paths: Optional[Tuple[str]] = None,
val_mask_paths: Optional[Tuple[str]] = None,
sampler: Optional[callable] = None,
device: int = 0,
) -> None:
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
segmentation task to perform this task on a different target domain.
Expand Down Expand Up @@ -82,6 +84,10 @@ def mean_teacher_adaptation(
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor semantic comment: I think that this is not necessarily a boundary mask. I think just calling it "mask" is more precise.

Copy link
Contributor Author

@stmartineau99 stmartineau99 Jul 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jonathan's lamella masker uses the term "boundary mask" so that is why I used it. It makes sense because the mask defines the boundary of the signal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are now using three different masks in this pipeline (gradient mask, boundary mask, membrane mask), we need different words to describe them. Correct me if there is a clearer way to refer to it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the context here:

  • The "gradient mask" is computed internally only, so we don't need to expose parameters related to it here. But if you want to refer to it for some explanations then calling it "gradient mask" is good.
  • "boundary mask" I would name differently, as we use this for accepting / rejecting samples. It does not necessarily have to be on the (spatial) boundary. (And I find the 'boundary of the signal' notion not so intuitive.) I would call it "sample mask".
  • I would call the other mask, which you called "membrane mask", "background mask", as we use it to enforce background label in the pseudo labels. In our case this is indeed for membranes, but it could also be for other structures.

val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
sampler: Accept or reject patches based on a condition.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The samplers for the datasets and the mean teacher trainer have slightly different meaning. See also comment below. I think the best approach here is to expose and document two different sampler arguments:

  • patch_sampler: is passed to get_unsupervised_loader
  • pseudo_label_sampler: is passed to MeanTeacherTrainer

Feel free to suggest better names ;).

device: GPU ID for training.
"""
assert (supervised_train_paths is None) == (supervised_val_paths is None)
is_2d, _ = _determine_ndim(patch_shape)
Expand Down Expand Up @@ -113,10 +119,21 @@ def mean_teacher_adaptation(
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

unsupervised_train_loader = get_unsupervised_loader(
unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
data_paths=unsupervised_train_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_train,
boundary_mask_paths=train_mask_paths,
sampler=sampler
)
unsupervised_val_loader = get_unsupervised_loader(
unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
data_paths=unsupervised_val_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_val,
boundary_mask_paths=val_mask_paths, sampler=sampler
)

if supervised_train_paths is not None:
Expand All @@ -133,7 +150,7 @@ def mean_teacher_adaptation(
supervised_train_loader = None
supervised_val_loader = None

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
trainer = self_training.MeanTeacherTrainer(
name=name,
model=model,
Expand All @@ -155,7 +172,7 @@ def mean_teacher_adaptation(
device=device,
reinit_teacher=reinit_teacher,
save_root=save_root,
sampler=sampler,
sampler=None, # TODO currently set to none cause I didn't want to pass the same sampler used by get_unsupervised_loader
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sampler here is applied to the pseudo-labels predicted by the teacher, to give a criterion for rejecting pseudo labels. In contrast, the sampler passed to the loaders rejects patches based on some criterion applied to the data. It makes sense to support both and to pass them with different names; see comment above.

)
trainer.fit(n_iterations)

Expand Down
69 changes: 61 additions & 8 deletions synapse_net/training/semisupervised_training.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Optional, Tuple

import numpy as np
import uuid
import h5py
import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from synapse_net.file_utils import read_mrc
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


Expand All @@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable:
])
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


def drop_mask_channel(x):
    """Keep only the first entry along the leading (channel) axis of ``x``.

    In this file it is composed after the raw transform so that the mask,
    which is stacked behind the raw data along axis 0, is not passed on.

    Args:
        x: Any sliceable object whose first axis indexes channels.

    Returns:
        ``x[:1]``: the first channel, with the leading axis kept (length one)
        rather than squeezed away.
    """
    # Slice instead of indexing with 0 so the channel axis is preserved.
    return x[:1]

class ComposedTransform:
    """Chain several callables into one transform, applied left to right.

    ``ComposedTransform(f, g)(x)`` evaluates ``g(f(x))``. With no functions the
    input is returned unchanged.
    """

    def __init__(self, *funcs):
        """Store the callables in the order they will be applied.

        Args:
            *funcs: Callables taking one argument; each receives the output of
                the previous one.
        """
        self.funcs = funcs

    def __call__(self, x):
        """Apply every stored callable to ``x`` in sequence and return the result."""
        for func in self.funcs:
            x = func(x)
        return x

class ChannelSplitterSampler:
    """Adapt a two-argument ``sampler(raw, mask)`` to a one-argument call.

    The wrapped sampler is called with the first two entries along the leading
    axis of the input (``x[0]`` as raw data, ``x[1]`` as mask).
    """

    def __init__(self, sampler):
        """Store the wrapped sampler.

        Args:
            sampler: Callable taking ``(raw, mask)`` and returning whether the
                patch is accepted, or ``None`` to accept every patch.
        """
        self.sampler = sampler

    def __call__(self, x):
        """Split ``x`` into raw and mask and delegate to the wrapped sampler.

        Args:
            x: Indexable input with raw data at index 0 and mask at index 1.

        Returns:
            The wrapped sampler's result, or ``True`` when no sampler was given.
        """
        # The loader wraps its ``sampler`` argument unconditionally, and that
        # argument defaults to None. Accept the patch instead of failing with
        # "'NoneType' object is not callable".
        if self.sampler is None:
            return True
        raw, mask = x[0], x[1]
        return self.sampler(raw, mask)

def get_unsupervised_loader(
data_paths: Tuple[str],
raw_key: str,
patch_shape: Tuple[int, int, int],
batch_size: int,
n_samples: Optional[int],
exclude_top_and_bottom: bool = False,
boundary_mask_paths: Optional[Tuple[str]] = None,
sampler: Optional[callable] = None,
exclude_top_and_bottom: bool = False, # TODO this seems unneccesary if we have a boundary mask - remove?
) -> torch.utils.data.DataLoader:
"""Get a dataloader for unsupervised segmentation training.
Expand All @@ -50,19 +75,46 @@ def get_unsupervised_loader(
based on the patch_shape and size of the volumes used for training.
exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
avoid artifacts at the border of tomograms.
boundary_mask_paths: The filepaths to the corresponding boundary masks for each tomogram.
sampler: Accept or reject patches based on a condition.
Returns:
The PyTorch dataloader.
"""

# We exclude the top and bottom slices where the tomogram reconstruction is bad.
# TODO this seems unneccesary if we have a boundary mask - remove?
if exclude_top_and_bottom:
roi = np.s_[5:-5, :, :]
else:
roi = None
# stack tomograms and masks and write to temp files to use as input to RawDataset()
if boundary_mask_paths is not None:
assert len(data_paths) == len(boundary_mask_paths), \
f"Expected equal number of data_paths and and boundary_masks_paths, got {len(data_paths)} data paths and {len(boundary_mask_paths)} mask paths."

stacked_paths = []
for i, (data_path, mask_path) in enumerate(zip(data_paths, boundary_mask_paths)):
raw = read_mrc(data_path)[0]
mask = read_mrc(mask_path)[0]
stacked = np.stack([raw, mask], axis=0)

tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
with h5py.File(tmp_path, "w") as f:
f.create_dataset("raw", data=stacked, compression="gzip")
stacked_paths.append(tmp_path)

# update variables for RawDataset()
data_paths = tuple(stacked_paths)
base_transform = torch_em.transform.get_raw_transform()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be adapted to only act on channel 0 (the actual raw data.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave this for the next PR.

raw_transform = ComposedTransform(base_transform, drop_mask_channel)
sampler = ChannelSplitterSampler(sampler)
with_channels = True
else:
raw_transform = torch_em.transform.get_raw_transform()
with_channels = False
sampler = None

_, ndim = _determine_ndim(patch_shape)
raw_transform = torch_em.transform.get_raw_transform()
transform = torch_em.transform.get_augmentations(ndim=ndim)

if n_samples is None:
Expand All @@ -71,18 +123,19 @@ def get_unsupervised_loader(
n_samples_per_ds = int(n_samples / len(data_paths))

augmentations = (weak_augmentations(), weak_augmentations())

datasets = [
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)
for path in data_paths
]
ds = torch.utils.data.ConcatDataset(datasets)

num_workers = 4 * batch_size
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
num_workers=num_workers, shuffle=True)
return loader


# TODO: use different paths for supervised and unsupervised training
# (We are currently not using this functionality directly, so this is not a high priority)
def semisupervised_training(
Expand Down
Loading