Skip to content

Commit ee072fa

Browse files
Fix a bug
1 parent b1c1880 commit ee072fa

File tree

2 files changed

+1
-5
lines changed

2 files changed

+1
-5
lines changed

dd_ranking/metrics/dd_ranking_obj.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def compute_hard_label_metrics(self, model, images, lr, hard_labels):
195195
metric = validate(
196196
model=model,
197197
loader=self.test_loader,
198-
aug_func=self.aug_func,
199198
device=self.device
200199
)
201200
if metric['top1'] > best_acc1:
@@ -240,7 +239,6 @@ def compute_soft_label_metrics(self, model, images, lr, soft_labels):
240239
metric = validate(
241240
model=model,
242241
loader=self.test_loader,
243-
aug_func=self.aug_func,
244242
device=self.device
245243
)
246244
if metric['top1'] > best_acc1:
@@ -500,7 +498,6 @@ def compute_hard_label_metrics(self, model, images, lr, hard_labels):
500498
metric = validate(
501499
model=model,
502500
loader=self.test_loader,
503-
aug_func=self.aug_func,
504501
device=self.device
505502
)
506503
if metric['top1'] > best_acc1:

dd_ranking/utils/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,6 @@ def _backward(_loss):
524524
def validate(
525525
model,
526526
loader,
527-
aug_func=None,
528527
device=torch.device('cuda'),
529528
logging=False,
530529
log_interval=10
@@ -545,7 +544,7 @@ def validate(
545544
last_batch = batch_idx == last_idx
546545
input = input.to(device)
547546
target = target.to(device)
548-
input = aug_func(input)
547+
549548
output = model(input)
550549
if isinstance(output, (tuple, list)):
551550
output = output[0]

0 commit comments

Comments
 (0)