Skip to content

Commit 7ebd59d

Browse files
committed
create subclass NewMeanTeacherTrainer
1 parent aa8cd15 commit 7ebd59d

File tree

1 file changed

+127
-27
lines changed

1 file changed

+127
-27
lines changed

synapse_net/training/domain_adaptation.py

Lines changed: 127 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import tempfile
33
from glob import glob
44
from pathlib import Path
5-
from typing import Optional, Tuple
5+
from typing import Optional, Tuple, Callable
6+
import time
67

78
import mrcfile
89
import torch
910
import torch_em
1011
import torch_em.self_training as self_training
12+
from torch_em.self_training.logger import SelfTrainingTensorboardLogger
1113
from elf.io import open_file
1214
from sklearn.model_selection import train_test_split
1315

@@ -19,8 +21,8 @@
1921
from ..inference.util import _Scaler
2022

2123
class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
22-
"""Compute pseudo labels based on model predictions, typically from a teacher model.
23-
By default, assumes that the first channel contains the transformed data and the second channel contains the background mask. # TODO update description
24+
"""Subclass of DefaultPseudoLabeler, which can subtract background from the pseudo labels if a background mask is provided.
25+
By default, assumes that the first channel contains the transformed raw data and the second channel contains the background mask.
2426
2527
Args:
2628
activation: Activation function applied to the teacher prediction.
@@ -32,23 +34,23 @@ class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
3234
confidence_mask_channel: A specific channel to use for computing the confidence mask.
3335
By default the confidence mask is computed across all channels independently.
3436
This is useful, if only one of the channels encodes a probability.
35-
raw_channel: # TODO add description
36-
background_mask_channel: # TODO add description
37+
raw_channel: Channel index of the raw data, which will be used as input to the teacher model
38+
background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
3739
"""
3840
def __init__(
3941
self,
4042
activation: Optional[torch.nn.Module] = None,
4143
confidence_threshold: Optional[float] = None,
4244
threshold_from_both_sides: bool = True,
4345
confidence_mask_channel: Optional[int] = None,
44-
raw_channel: Optional[int] = 0,
46+
raw_channel: Optional[int] = 0,
4547
background_mask_channel: Optional[int] = 1,
4648
):
4749
super().__init__(activation, confidence_threshold, threshold_from_both_sides)
50+
self.confidence_mask_channel = confidence_mask_channel
4851
self.raw_channel = raw_channel
4952
self.background_mask_channel = background_mask_channel
50-
self.confidence_mask_channel = confidence_mask_channel
51-
53+
5254
def _subtract_background(self, pseudo_labels: torch.Tensor, background_mask: torch.Tensor):
5355
bool_mask = background_mask.bool()
5456
return pseudo_labels.masked_fill(bool_mask, 0)
@@ -63,10 +65,12 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
6365
Returns:
6466
The pseudo-labels.
6567
"""
66-
if self.background_mask_channel is not None:
67-
if input_.ndim != 5:
68-
raise ValueError(f"Expect data with 5 dimensions (B, C, D, H, W), got shape {input_.shape}.")
69-
68+
if input_.ndim != 5:
69+
raise ValueError(f"Expect data with 5 dimensions (B, C, D, H, W), got shape {input_.shape}.")
70+
71+
has_background_mask = input_.shape[1] > 1
72+
73+
if has_background_mask:
7074
if self.background_mask_channel > input_.shape[1]:
7175
raise ValueError(f"Channel index {self.background_mask_channel} is out of bounds for shape {input_.shape}.")
7276

@@ -88,11 +92,112 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
8892
size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
8993
label_mask = label_mask.expand(*size)
9094

91-
if self.background_mask_channel is not None:
95+
if has_background_mask:
9296
pseudo_labels = self._subtract_background(pseudo_labels, background_mask)
9397

9498
return pseudo_labels, label_mask
9599

100+
class NewMeanTeacherTrainer(self_training.MeanTeacherTrainer):
101+
"""Subclass of MeanTeacherTrainer, updated to handle cases where the background mask is provided.
102+
Once the pseudo labels are computed, the second channel of the teacher input is dropped, if it exists.
103+
The second channel of the student input is also dropped, if it exists, since it is not needed for training.
104+
105+
Args:
106+
activation: Activation function applied to the teacher prediction.
107+
confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
108+
If None is given no mask will be computed.
109+
threshold_from_both_sides: Whether to include both values bigger than the threshold
110+
and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
111+
The former should be used for binary labels, the latter for for multiclass labels.
112+
confidence_mask_channel: A specific channel to use for computing the confidence mask.
113+
By default the confidence mask is computed across all channels independently.
114+
This is useful, if only one of the channels encodes a probability.
115+
raw_channel: Channel index of the raw data to be used as input to the teacher model.
116+
background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
117+
"""
118+
def __init__(
119+
self,
120+
model: torch.nn.Module,
121+
unsupervised_train_loader: torch.utils.data.DataLoader,
122+
unsupervised_loss: Callable,
123+
pseudo_labeler: Callable,
124+
supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
125+
unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
126+
supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
127+
supervised_loss: Optional[Callable] = None,
128+
unsupervised_loss_and_metric: Optional[Callable] = None,
129+
supervised_loss_and_metric: Optional[Callable] = None,
130+
logger=SelfTrainingTensorboardLogger,
131+
momentum: float = 0.999,
132+
reinit_teacher: Optional[bool] = None,
133+
sampler: Optional[Callable] = None,
134+
**kwargs,
135+
):
136+
super().__init__(model, unsupervised_train_loader, unsupervised_loss, pseudo_labeler,
137+
supervised_train_loader, unsupervised_val_loader, supervised_val_loader,
138+
supervised_loss, unsupervised_loss_and_metric, supervised_loss_and_metric,
139+
logger, momentum, reinit_teacher, sampler, **kwargs)
140+
141+
def _train_epoch_unsupervised(self, progress, forward_context, backprop):
142+
self.model.train()
143+
144+
n_iter = 0
145+
t_per_iter = time.time()
146+
147+
# Sample from both the supervised and unsupervised loader.
148+
for xu1, xu2 in self.unsupervised_train_loader:
149+
150+
# Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input).
151+
if xu2.shape[1] > 1:
152+
xu2 = xu2[:, :1].contiguous()
153+
154+
xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
155+
156+
teacher_input, model_input = xu1, xu2
157+
158+
with forward_context(), torch.no_grad():
159+
# Compute the pseudo labels.
160+
pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
161+
162+
# Drop the second channel for xu1 (teacher input) after computing the pseudo labels.
163+
if xu1.shape[1] > 1:
164+
xu1 = xu1[:, :1].contiguous()
165+
166+
# If we have a sampler then check if the current batch matches the condition for inclusion in training.
167+
if self.sampler is not None:
168+
keep_batch = self.sampler(pseudo_labels, label_filter)
169+
if not keep_batch:
170+
continue
171+
172+
self.optimizer.zero_grad()
173+
# Perform unsupervised training
174+
with forward_context():
175+
loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
176+
backprop(loss)
177+
178+
if self.logger is not None:
179+
with torch.no_grad(), forward_context():
180+
pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
181+
self.logger.log_train_unsupervised(
182+
self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
183+
)
184+
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
185+
self.logger.log_lr(self._iteration, lr)
186+
if self.pseudo_labeler.confidence_threshold is not None:
187+
self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
188+
189+
with torch.no_grad():
190+
self._momentum_update()
191+
192+
self._iteration += 1
193+
n_iter += 1
194+
if self._iteration >= self.max_iteration:
195+
break
196+
progress.update(1)
197+
198+
t_per_iter = (time.time() - t_per_iter) / n_iter
199+
return t_per_iter
200+
96201
def mean_teacher_adaptation(
97202
name: str,
98203
unsupervised_train_paths: Tuple[str],
@@ -114,13 +219,11 @@ def mean_teacher_adaptation(
114219
train_sample_mask_paths: Optional[Tuple[str]] = None,
115220
val_sample_mask_paths: Optional[Tuple[str]] = None,
116221
train_background_mask_paths: Optional[Tuple[str]] = None,
117-
train_mask_paths: Optional[Tuple[str]] = None,
118-
val_mask_paths: Optional[Tuple[str]] = None,
119222
patch_sampler: Optional[callable] = None,
120223
pseudo_label_sampler: Optional[callable] = None,
121224
device: int = 0,
122225
) -> None:
123-
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
226+
"""Run domain adapation to transfer a network trained on a source domain for a supervised
124227
segmentation task to perform this task on a different target domain.
125228
126229
We support different domain adaptation settings:
@@ -163,15 +266,14 @@ def mean_teacher_adaptation(
163266
based on the patch_shape and size of the volumes used for training.
164267
n_samples_val: The number of val samples per epoch. By default this will be estimated
165268
based on the patch_shape and size of the volumes used for validation.
166-
train_sample_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
167-
val_sample_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
168-
train_background_mask_paths: # TODO add description
269+
train_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
270+
patches for training.
271+
val_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
272+
patches for validation.
273+
train_background_mask_paths: Filepaths to the background masks used for training.
274+
Background masks are used to subtract background from the pseudo labels before the forward pass.
169275
patch_sampler: A sampler for rejecting patches based on a defined conditon.
170276
pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
171-
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
172-
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
173-
patch_sampler: Accept or reject patches based on a condition.
174-
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
175277
device: GPU ID for training.
176278
"""
177279
assert (supervised_train_paths is None) == (supervised_val_paths is None)
@@ -192,7 +294,7 @@ def mean_teacher_adaptation(
192294
if os.path.isdir(source_checkpoint):
193295
model = torch_em.util.load_model(source_checkpoint)
194296
else:
195-
model = torch.load(source_checkpoint, weights_only=False)
297+
model = torch.load(source_checkpoint)
196298
reinit_teacher = False
197299

198300
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
@@ -206,7 +308,7 @@ def mean_teacher_adaptation(
206308

207309
loss = self_training.DefaultSelfTrainingLoss()
208310
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
209-
311+
210312
unsupervised_train_loader = get_unsupervised_loader(
211313
data_paths=unsupervised_train_paths,
212314
raw_key=raw_key,
@@ -215,7 +317,6 @@ def mean_teacher_adaptation(
215317
n_samples=n_samples_train,
216318
sample_mask_paths=train_sample_mask_paths,
217319
background_mask_paths=train_background_mask_paths,
218-
sample_mask_paths=train_mask_paths,
219320
sampler=patch_sampler
220321
)
221322
unsupervised_val_loader = get_unsupervised_loader(
@@ -226,7 +327,6 @@ def mean_teacher_adaptation(
226327
n_samples=n_samples_val,
227328
sample_mask_paths=val_sample_mask_paths,
228329
background_mask_paths=None,
229-
sample_mask_paths=val_mask_paths,
230330
sampler=patch_sampler
231331
)
232332

0 commit comments

Comments
 (0)