Skip to content

Commit b9558fc

Browse files
fix hard label
1 parent c11908f commit b9558fc

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ddranking/metrics/hard_label.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_l
196196
if epoch > 0.8 * self.num_epochs and (epoch + 1) % self.test_interval == 0:
197197
metric = validate(
198198
model=model,
199-
loader=self.test_loader,
199+
loader=self.test_loader_real if mode == 'real' else self.test_loader_syn,
200200
device=self.device
201201
)
202202
if metric['top1'] > best_acc1:
@@ -256,17 +256,17 @@ def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, hard_
256256
)
257257
full_data_hard_label_acc = self.compute_hard_label_metrics(
258258
model=model,
259-
image_tensor=self.images_train,
259+
image_tensor=None,
260260
image_path=None,
261261
lr=self.default_lr,
262-
hard_labels=self.labels_train,
262+
hard_labels=None,
263263
mode='real'
264264
)
265265
del model
266266
print(f"Full data hard label acc: {full_data_hard_label_acc:.2f}%")
267267

268268
print("Caculating random data hard label metrics...")
269-
random_images, random_data_hard_labels = get_random_images(self.images_train, self.labels_train, self.class_indices_train, self.ipc)
269+
random_images, random_data_hard_labels = get_random_images(self.dst_train, self.class_indices, self.ipc)
270270
random_data_hard_label_acc, best_lr = self.hyper_param_search_for_hard_label(
271271
image_tensor=random_images,
272272
image_path=None,

0 commit comments

Comments
 (0)