Skip to content

Commit bbacefe

Browse files
committed
update trainer to compute entropy on validation model
1 parent 511b209 commit bbacefe

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

trainer/src/trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ def start_training(self, config):
220220
self.model = create_first_model_with_random_weights(model_dir)
221221
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01,
222222
momentum=0.99, nesterov=True)
223+
224+
# validation model needs to be stored for active learning entropy computation.
225+
self.validation_model = copy.deepcopy(self.model)
226+
223227
self.model.train()
224228
self.training = True
225229

@@ -326,7 +330,7 @@ def train_one_epoch(self):
326330
project_dir=project_dir,
327331
train_annot_dir=self.train_config['train_annot_dir'],
328332
val_annot_dir=self.train_config['val_annot_dir'],
329-
model=self.model,
333+
model=self.validation_model,
330334
in_w=self.in_w,
331335
out_w=self.out_w
332336
)
@@ -379,6 +383,8 @@ def validation(self):
379383
cur_metrics['f1'], prev_metrics['f1'])
380384
if was_saved:
381385
self.epochs_without_progress = 0
386+
self.validation_model = copy.deepcopy(self.model) # ✅ update to new best model
387+
382388

383389
# Clear uncertainty cache
384390
project_dir = os.path.dirname(self.train_config['seg_dir'])

0 commit comments

Comments
 (0)