File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed
Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -220,6 +220,10 @@ def start_training(self, config):
220220 self .model = create_first_model_with_random_weights (model_dir )
221221 self .optimizer = torch .optim .SGD (self .model .parameters (), lr = 0.01 ,
222222 momentum = 0.99 , nesterov = True )
223+
224+ # validation model needs to be stored for active learning entropy computation.
225+ self .validation_model = copy .deepcopy (self .model )
226+
223227 self .model .train ()
224228 self .training = True
225229
@@ -326,7 +330,7 @@ def train_one_epoch(self):
326330 project_dir = project_dir ,
327331 train_annot_dir = self .train_config ['train_annot_dir' ],
328332 val_annot_dir = self .train_config ['val_annot_dir' ],
329- model = self .model ,
333+ model = self .validation_model ,
330334 in_w = self .in_w ,
331335 out_w = self .out_w
332336 )
@@ -379,6 +383,8 @@ def validation(self):
379383 cur_metrics ['f1' ], prev_metrics ['f1' ])
380384 if was_saved :
381385 self .epochs_without_progress = 0
386+ self .validation_model = copy .deepcopy (self .model ) # ✅ update to new best model
387+
382388
383389 # Clear uncertainty cache
384390 project_dir = os .path .dirname (self .train_config ['seg_dir' ])
You can’t perform that action at this time.
0 commit comments