22import  tempfile 
33from  glob  import  glob 
44from  pathlib  import  Path 
5- from  typing  import  Optional , Tuple 
5+ from  typing  import  Optional , Tuple , Callable 
6+ import  time 
67
78import  mrcfile 
89import  torch 
910import  torch_em 
1011import  torch_em .self_training  as  self_training 
12+ from  torch_em .self_training .logger  import  SelfTrainingTensorboardLogger 
1113from  elf .io  import  open_file 
1214from  sklearn .model_selection  import  train_test_split 
1315
1921from  ..inference .util  import  _Scaler 
2022
2123class  NewPseudoLabeler (self_training .DefaultPseudoLabeler ):
22-     """Compute pseudo labels based on model predictions, typically  from a teacher model . 
23-         By default, assumes that the first channel contains the transformed data and the second channel contains the background mask. # TODO update description  
24+     """Subclass of DefaultPseudoLabeler, which can subtract background  from the pseudo labels if a background mask is provided . 
25+         By default, assumes that the first channel contains the transformed raw  data and the second channel contains the background mask. 
2426
2527    Args: 
2628        activation: Activation function applied to the teacher prediction. 
@@ -32,23 +34,23 @@ class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
3234        confidence_mask_channel: A specific channel to use for computing the confidence mask. 
3335            By default the confidence mask is computed across all channels independently. 
3436            This is useful, if only one of the channels encodes a probability. 
35-         raw_channel: # TODO add description  
36-         background_mask_channel: # TODO add description  
37+         raw_channel: Channel index of the raw data, which will be used as input to the teacher model  
38+         background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.  
3739    """ 
3840    def  __init__ (
3941        self ,
4042        activation : Optional [torch .nn .Module ] =  None ,
4143        confidence_threshold : Optional [float ] =  None ,
4244        threshold_from_both_sides : bool  =  True ,
4345        confidence_mask_channel : Optional [int ] =  None ,
44-         raw_channel : Optional [int ] =  0 ,  
46+         raw_channel : Optional [int ] =  0 ,
4547        background_mask_channel : Optional [int ] =  1 ,
4648    ):
4749        super ().__init__ (activation , confidence_threshold , threshold_from_both_sides )
50+         self .confidence_mask_channel  =  confidence_mask_channel 
4851        self .raw_channel  =  raw_channel 
4952        self .background_mask_channel  =  background_mask_channel 
50-         self .confidence_mask_channel  =  confidence_mask_channel 
51-     
53+         
5254    def  _subtract_background (self , pseudo_labels : torch .Tensor , background_mask : torch .Tensor ):
5355        bool_mask  =  background_mask .bool ()
5456        return  pseudo_labels .masked_fill (bool_mask , 0 )
@@ -63,10 +65,12 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
6365        Returns: 
6466            The pseudo-labels. 
6567        """ 
66-         if  self .background_mask_channel  is  not   None :
67-             if  input_ .ndim  !=  5 :
68-                 raise  ValueError (f"Expect data with 5 dimensions (B, C, D, H, W), got shape { input_ .shape }  ." ) 
69-     
68+         if  input_ .ndim  !=  5 :
69+             raise  ValueError (f"Expect data with 5 dimensions (B, C, D, H, W), got shape { input_ .shape }  ." )
70+         
71+         has_background_mask  =  input_ .shape [1 ] >  1  
72+ 
73+         if  has_background_mask :
7074            if  self .background_mask_channel  >  input_ .shape [1 ]:
7175                raise  ValueError (f"Channel index { self .background_mask_channel }   is out of bounds for shape { input_ .shape }  ." )
7276
@@ -88,11 +92,112 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
8892                size  =  (pseudo_labels .shape [0 ], pseudo_labels .shape [1 ], * ([- 1 ] *  (pseudo_labels .ndim  -  2 )))
8993                label_mask  =  label_mask .expand (* size )
9094
91-         if  self . background_mask_channel   is   not   None :   
95+         if  has_background_mask : 
9296            pseudo_labels  =  self ._subtract_background (pseudo_labels , background_mask )
9397
9498        return  pseudo_labels , label_mask 
9599
100+ class  NewMeanTeacherTrainer (self_training .MeanTeacherTrainer ):
101+     """Subclass of MeanTeacherTrainer, updated to handle cases where the background mask is provided.  
102+     Once the pseudo labels are computed, the second channel of the teacher input is dropped, if it exists. 
103+     The second channel of the student input is also dropped, if it exists, since it is not needed for training. 
104+ 
105+     Args: 
106+         activation: Activation function applied to the teacher prediction. 
107+         confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. 
108+             If None is given no mask will be computed. 
109+         threshold_from_both_sides: Whether to include both values bigger than the threshold 
110+             and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask. 
111+             The former should be used for binary labels, the latter for for multiclass labels. 
112+         confidence_mask_channel: A specific channel to use for computing the confidence mask. 
113+             By default the confidence mask is computed across all channels independently. 
114+             This is useful, if only one of the channels encodes a probability. 
115+         raw_channel: Channel index of the raw data to be used as input to the teacher model. 
116+         background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels. 
117+     """ 
118+     def  __init__ (
119+         self ,
120+         model : torch .nn .Module ,
121+         unsupervised_train_loader : torch .utils .data .DataLoader ,
122+         unsupervised_loss : Callable ,
123+         pseudo_labeler : Callable ,
124+         supervised_train_loader : Optional [torch .utils .data .DataLoader ] =  None ,
125+         unsupervised_val_loader : Optional [torch .utils .data .DataLoader ] =  None ,
126+         supervised_val_loader : Optional [torch .utils .data .DataLoader ] =  None ,
127+         supervised_loss : Optional [Callable ] =  None ,
128+         unsupervised_loss_and_metric : Optional [Callable ] =  None ,
129+         supervised_loss_and_metric : Optional [Callable ] =  None ,
130+         logger = SelfTrainingTensorboardLogger ,
131+         momentum : float  =  0.999 ,
132+         reinit_teacher : Optional [bool ] =  None ,
133+         sampler : Optional [Callable ] =  None ,
134+         ** kwargs ,
135+     ):
136+         super ().__init__ (model , unsupervised_train_loader , unsupervised_loss , pseudo_labeler ,
137+                          supervised_train_loader , unsupervised_val_loader , supervised_val_loader ,
138+                          supervised_loss , unsupervised_loss_and_metric , supervised_loss_and_metric ,
139+                          logger , momentum , reinit_teacher , sampler , ** kwargs )
140+ 
141+     def  _train_epoch_unsupervised (self , progress , forward_context , backprop ):
142+         self .model .train ()
143+ 
144+         n_iter  =  0 
145+         t_per_iter  =  time .time ()
146+ 
147+         # Sample from both the supervised and unsupervised loader. 
148+         for  xu1 , xu2  in  self .unsupervised_train_loader :
149+             
150+             # Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input). 
151+             if  xu2 .shape [1 ] >  1 :
152+                 xu2  =  xu2 [:, :1 ].contiguous ()
153+ 
154+             xu1 , xu2  =  xu1 .to (self .device , non_blocking = True ), xu2 .to (self .device , non_blocking = True )
155+ 
156+             teacher_input , model_input  =  xu1 , xu2 
157+             
158+             with  forward_context (), torch .no_grad ():
159+                 # Compute the pseudo labels. 
160+                 pseudo_labels , label_filter  =  self .pseudo_labeler (self .teacher , teacher_input )
161+ 
162+             # Drop the second channel for xu1 (teacher input) after computing the pseudo labels. 
163+             if  xu1 .shape [1 ] >  1 :
164+                 xu1  =  xu1 [:, :1 ].contiguous ()
165+ 
166+             # If we have a sampler then check if the current batch matches the condition for inclusion in training. 
167+             if  self .sampler  is  not   None :
168+                 keep_batch  =  self .sampler (pseudo_labels , label_filter )
169+                 if  not  keep_batch :
170+                     continue 
171+ 
172+             self .optimizer .zero_grad ()
173+             # Perform unsupervised training 
174+             with  forward_context ():
175+                 loss  =  self .unsupervised_loss (self .model , model_input , pseudo_labels , label_filter )
176+             backprop (loss )
177+ 
178+             if  self .logger  is  not   None :
179+                 with  torch .no_grad (), forward_context ():
180+                     pred  =  self .model (model_input ) if  self ._iteration  %  self .log_image_interval  ==  0  else  None 
181+                 self .logger .log_train_unsupervised (
182+                     self ._iteration , loss , xu1 , xu2 , pred , pseudo_labels , label_filter 
183+                 )
184+                 lr  =  [pm ["lr" ] for  pm  in  self .optimizer .param_groups ][0 ]
185+                 self .logger .log_lr (self ._iteration , lr )
186+                 if  self .pseudo_labeler .confidence_threshold  is  not   None :
187+                     self .logger .log_ct (self ._iteration , self .pseudo_labeler .confidence_threshold )
188+ 
189+             with  torch .no_grad ():
190+                 self ._momentum_update ()
191+ 
192+             self ._iteration  +=  1 
193+             n_iter  +=  1 
194+             if  self ._iteration  >=  self .max_iteration :
195+                 break 
196+             progress .update (1 )
197+ 
198+         t_per_iter  =  (time .time () -  t_per_iter ) /  n_iter 
199+         return  t_per_iter 
200+ 
96201def  mean_teacher_adaptation (
97202    name : str ,
98203    unsupervised_train_paths : Tuple [str ],
@@ -114,13 +219,11 @@ def mean_teacher_adaptation(
114219    train_sample_mask_paths : Optional [Tuple [str ]] =  None ,
115220    val_sample_mask_paths : Optional [Tuple [str ]] =  None ,
116221    train_background_mask_paths : Optional [Tuple [str ]] =  None ,
117-     train_mask_paths : Optional [Tuple [str ]] =  None ,
118-     val_mask_paths : Optional [Tuple [str ]] =  None ,
119222    patch_sampler : Optional [callable ] =  None ,
120223    pseudo_label_sampler : Optional [callable ] =  None ,
121224    device : int  =  0 ,
122225) ->  None :
123-     """Run domain adaptation  to transfer a network trained on a source domain for a supervised 
226+     """Run domain adapation  to transfer a network trained on a source domain for a supervised 
124227    segmentation task to perform this task on a different target domain. 
125228
126229    We support different domain adaptation settings: 
@@ -163,15 +266,14 @@ def mean_teacher_adaptation(
163266            based on the patch_shape and size of the volumes used for training. 
164267        n_samples_val: The number of val samples per epoch. By default this will be estimated 
165268            based on the patch_shape and size of the volumes used for validation. 
166-         train_sample_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.  
167-         val_sample_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.  
168-         train_background_mask_paths: # TODO add description 
269+         train_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject  
270+             patches for training. 
271+         val_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject  
272+             patches for validation.  
273+         train_background_mask_paths: Filepaths to the background masks used for training. 
274+             Background masks are used to subtract background from the pseudo labels before the forward pass.  
169275        patch_sampler: A sampler for rejecting patches based on a defined conditon.  
170276        pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition. 
171-         train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.  
172-         val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.  
173-         patch_sampler: Accept or reject patches based on a condition. 
174-         pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.  
175277        device: GPU ID for training.  
176278    """ 
177279    assert  (supervised_train_paths  is  None ) ==  (supervised_val_paths  is  None )
@@ -192,7 +294,7 @@ def mean_teacher_adaptation(
192294        if  os .path .isdir (source_checkpoint ):
193295            model  =  torch_em .util .load_model (source_checkpoint )
194296        else :
195-             model  =  torch .load (source_checkpoint ,  weights_only = False )
297+             model  =  torch .load (source_checkpoint )
196298        reinit_teacher  =  False 
197299
198300    optimizer  =  torch .optim .Adam (model .parameters (), lr = 1e-4 )
@@ -206,7 +308,7 @@ def mean_teacher_adaptation(
206308
207309    loss  =  self_training .DefaultSelfTrainingLoss ()
208310    loss_and_metric  =  self_training .DefaultSelfTrainingLossAndMetric ()
209- 
311+     
210312    unsupervised_train_loader  =  get_unsupervised_loader (
211313        data_paths = unsupervised_train_paths , 
212314        raw_key = raw_key , 
@@ -215,7 +317,6 @@ def mean_teacher_adaptation(
215317        n_samples = n_samples_train , 
216318        sample_mask_paths = train_sample_mask_paths , 
217319        background_mask_paths = train_background_mask_paths ,
218-         sample_mask_paths = train_mask_paths , 
219320        sampler = patch_sampler 
220321    )
221322    unsupervised_val_loader  =  get_unsupervised_loader (
@@ -226,7 +327,6 @@ def mean_teacher_adaptation(
226327        n_samples = n_samples_val , 
227328        sample_mask_paths = val_sample_mask_paths , 
228329        background_mask_paths = None ,
229-         sample_mask_paths = val_mask_paths , 
230330        sampler = patch_sampler 
231331    )
232332
0 commit comments