Skip to content

Commit 8a60f62

Browse files
Fix the training setting for real dataset
1 parent 6b78a24 commit 8a60f62

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

ddranking/metrics/hard_label.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,18 @@ def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_l
167167
num_workers=self.num_workers, shuffle=True)
168168

169169
loss_fn = torch.nn.CrossEntropyLoss()
170-
optimizer = get_optimizer(self.optimizer, model, lr, self.weight_decay, self.momentum)
170+
171+
# We use default optimizer and lr scheduler to train a model on real data. These parameters are empirically set.
172+
if mode == 'real':
173+
if self.model_name.startswith('ConvNet'):
174+
optimizer = get_optimizer('sgd', model, lr, 0.0005, 0.9)
175+
elif self.model_name.startswith('ResNet'):
176+
optimizer = get_optimizer('adamw', model, lr, 0.01, 0.9)
177+
else: # TODO: add more models
178+
optimizer = get_optimizer(self.optimizer, model, lr, self.weight_decay, self.momentum)
179+
else:
180+
optimizer = get_optimizer(self.optimizer, model, lr, self.weight_decay, self.momentum)
181+
# Learning rate scheduler doesn't affect the results too much.
171182
lr_scheduler = get_lr_scheduler(self.lr_scheduler, optimizer, self.num_epochs)
172183

173184
best_acc1 = 0

ddranking/metrics/soft_label.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,16 @@ def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_l
223223
num_workers=self.num_workers, shuffle=True)
224224

225225
loss_fn = torch.nn.CrossEntropyLoss()
226-
optimizer = get_optimizer(self.optimizer, model, lr, self.weight_decay, self.momentum)
226+
# We use default optimizer and lr scheduler to train a model on real data. These parameters are empirically set.
227+
if mode == 'real':
228+
if self.model_name.startswith('ConvNet'):
229+
optimizer = get_optimizer('sgd', model, lr, 0.0005, 0.9)
230+
elif self.model_name.startswith('ResNet'):
231+
optimizer = get_optimizer('adamw', model, lr, 0.01, 0.9)
232+
else: # TODO: add more models
233+
optimizer = get_optimizer(self.optimizer, model, lr, self.weight_decay, self.momentum)
234+
else:
235+
optimizer = get_optimizer(self.optimizer, model, lr, self.weight_decay, self.momentum)
227236
lr_scheduler = get_lr_scheduler(self.lr_scheduler, optimizer, self.num_epochs)
228237

229238
best_acc1 = 0

0 commit comments

Comments
 (0)