Skip to content

Commit 6b78a24

Browse files
Fix some issues
1 parent de75469 commit 6b78a24

File tree

3 files changed

+125
-62
lines changed

3 files changed

+125
-62
lines changed

ddranking/metrics/hard_label.py

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -50,15 +50,18 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
5050
num_workers = self.config.get('num_workers')
5151
device = self.config.get('device')
5252

53-
channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset,
54-
real_data_path,
55-
im_size,
56-
use_zca,
57-
custom_train_trans,
58-
custom_val_trans,
59-
device)
60-
self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)
61-
self.test_loader = DataLoader(dst_test, batch_size=real_batch_size, num_workers=num_workers, shuffle=False)
53+
channel, im_size, num_classes, dst_train, dst_test_real, dst_test_syn, class_map, class_map_inv = get_dataset(dataset,
54+
real_data_path,
55+
im_size,
56+
use_zca,
57+
custom_val_trans,
58+
device)
59+
# self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)
60+
self.class_indices = self.get_class_indices(dst_train, class_map, num_classes)
61+
self.dst_train = dst_train
62+
63+
self.test_loader_real = DataLoader(dst_test_real, batch_size=real_batch_size, num_workers=num_workers, shuffle=False)
64+
self.test_loader_syn = DataLoader(dst_test_syn, batch_size=syn_batch_size, num_workers=num_workers, shuffle=False)
6265

6366
# data info
6467
self.im_size = im_size
@@ -115,6 +118,15 @@ def load_real_data(self, dataset, class_map, num_classes):
115118

116119
return images_all, labels_all, class_indices
117120

121+
def get_class_indices(self, dataset, class_map, num_classes):
    """Group sample indices of ``dataset`` by their mapped class id.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs. If it
            exposes a ``targets`` sequence (and no ``target_transform``),
            the labels are read from it directly so that no sample has to
            be loaded or transformed.
        class_map: Mapping from the raw dataset label to the class id used
            as an index into the returned list.
        num_classes: Number of classes, i.e. length of the returned list.

    Returns:
        A list of ``num_classes`` lists; entry ``c`` holds the positions of
        all samples whose mapped label is ``c``, in dataset order.

    Raises:
        KeyError: If a label is missing from ``class_map``.
        IndexError: If a mapped label is not below ``num_classes``.
    """
    class_indices = [[] for _ in range(num_classes)]

    # Fast path: use label metadata only when it is guaranteed to agree with
    # what __getitem__ returns (no target_transform, same length).
    targets = getattr(dataset, 'targets', None)
    use_metadata = (
        targets is not None
        and getattr(dataset, 'target_transform', None) is None
    )
    if use_metadata:
        try:
            use_metadata = len(targets) == len(dataset)
        except TypeError:
            use_metadata = False

    if use_metadata:
        labels = targets
    else:
        # Fallback: iterate the dataset itself (loads every sample).
        labels = (label for _, label in dataset)

    for i, label in enumerate(labels):
        if torch.is_tensor(label):
            label = label.item()
        class_indices[class_map[label]].append(i)
    return class_indices
129+
118130
def hyper_param_search_for_hard_label(self, image_tensor, image_path, hard_labels, mode='real'):
119131
lr_list = [0.001, 0.005, 0.01, 0.05, 0.1]
120132
best_acc = 0
@@ -144,11 +156,13 @@ def hyper_param_search_for_hard_label(self, image_tensor, image_path, hard_label
144156
return best_acc, best_lr
145157

146158
def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_labels, mode='real'):
147-
148-
if image_tensor is None:
149-
hard_label_dataset = datasets.ImageFolder(root=image_path, transform=self.custom_train_trans)
159+
if mode == 'real':
160+
hard_label_dataset = self.dst_train
150161
else:
151-
hard_label_dataset = TensorDataset(image_tensor, hard_labels)
162+
if image_tensor is None:
163+
hard_label_dataset = datasets.ImageFolder(root=image_path, transform=self.custom_train_trans)
164+
else:
165+
hard_label_dataset = TensorDataset(image_tensor, hard_labels)
152166
train_loader = DataLoader(hard_label_dataset, batch_size=self.real_batch_size if mode == 'real' else self.syn_batch_size,
153167
num_workers=self.num_workers, shuffle=True)
154168

ddranking/metrics/soft_label.py

Lines changed: 39 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ class SoftLabelEvaluator:
2222
def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path: str='./dataset/', ipc: int=10, model_name: str='ConvNet-3',
2323
soft_label_criterion: str='kl', data_aug_func: str='cutmix', aug_params: dict={'beta': 1.0}, soft_label_mode: str='S',
2424
optimizer: str='sgd', lr_scheduler: str='step', temperature: float=1.0, weight_decay: float=0.0005,
25-
momentum: float=0.9, num_eval: int=5, im_size: tuple=(32, 32), num_epochs: int=300, use_zca: bool=False,
25+
momentum: float=0.9, num_eval: int=5, im_size: tuple=(32, 32), num_epochs: int=300, use_zca: bool=False, use_aug_for_hard: bool=False,
2626
real_batch_size: int=256, syn_batch_size: int=256, default_lr: float=0.01, save_path: str=None, stu_use_torchvision: bool=False,
2727
tea_use_torchvision: bool=False, num_workers: int=4, teacher_dir: str='./teacher_models', custom_train_trans: transforms.Compose=None,
2828
custom_val_trans: transforms.Compose=None, device: str="cuda"):
@@ -46,6 +46,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
4646
im_size = self.config.get('im_size')
4747
num_epochs = self.config.get('num_epochs')
4848
use_zca = self.config.get('use_zca')
49+
use_aug_for_hard = self.config.get('use_aug_for_hard')
4950
real_batch_size = self.config.get('real_batch_size')
5051
syn_batch_size = self.config.get('syn_batch_size')
5152
default_lr = self.config.get('default_lr')
@@ -58,15 +59,18 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
5859
teacher_dir = self.config.get('teacher_dir')
5960
device = self.config.get('device')
6061

61-
channel, im_size, num_classes, dst_train, dst_test, class_map, class_map_inv = get_dataset(dataset,
62-
real_data_path,
63-
im_size,
64-
use_zca,
65-
custom_train_trans,
66-
custom_val_trans,
67-
device)
68-
self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)
69-
self.test_loader = DataLoader(dst_test, batch_size=real_batch_size, num_workers=num_workers, shuffle=False)
62+
channel, im_size, num_classes, dst_train, dst_test_real, dst_test_syn, class_map, class_map_inv = get_dataset(dataset,
63+
real_data_path,
64+
im_size,
65+
use_zca,
66+
custom_val_trans,
67+
device)
68+
# self.images_train, self.labels_train, self.class_indices_train = self.load_real_data(dst_train, class_map, num_classes)
69+
self.class_indices = self.get_class_indices(dst_train, class_map, num_classes)
70+
self.dst_train = dst_train
71+
72+
self.test_loader_real = DataLoader(dst_test_real, batch_size=real_batch_size, num_workers=num_workers, shuffle=False)
73+
self.test_loader_syn = DataLoader(dst_test_syn, batch_size=syn_batch_size, num_workers=num_workers, shuffle=False)
7074

7175
self.soft_label_mode = soft_label_mode
7276
self.soft_label_criterion = soft_label_criterion
@@ -94,6 +98,7 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
9498
self.num_workers = num_workers
9599
self.device = device
96100

101+
# data augmentation
97102
if data_aug_func == 'dsa':
98103
self.aug_func = DSA(aug_params)
99104
elif data_aug_func == 'mixup':
@@ -102,7 +107,9 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
102107
self.aug_func = Cutmix(aug_params)
103108
else:
104109
self.aug_func = None
110+
self.use_aug_for_hard = use_aug_for_hard
105111

112+
# save path
106113
if not save_path:
107114
save_path = f"./results/{dataset}/{model_name}/ipc{ipc}/obj_scores.csv"
108115
if not os.path.exists(os.path.dirname(save_path)):
@@ -139,6 +146,15 @@ def load_real_data(self, dataset, class_map, num_classes):
139146

140147
return images_all, labels_all, class_indices
141148

149+
def get_class_indices(self, dataset, class_map, num_classes):
    """Group sample indices of ``dataset`` by their mapped class id.

    Labels are taken from the cheapest available source, in this order:

    1. ``dataset.imgs`` -- ``(path, label)`` pairs (original behavior).
    2. ``dataset.targets`` -- used only when no ``target_transform`` is set,
       so the metadata matches what ``__getitem__`` would return.
    3. Iterating ``dataset`` itself as ``(sample, label)`` pairs.

    Sources 2 and 3 are only reached for datasets without ``imgs``, which
    previously raised ``AttributeError``.

    Args:
        dataset: Dataset providing labels through one of the sources above.
        class_map: Mapping from the raw dataset label to the class id used
            as an index into the returned list.
        num_classes: Number of classes, i.e. length of the returned list.

    Returns:
        A list of ``num_classes`` lists; entry ``c`` holds the positions of
        all samples whose mapped label is ``c``, in dataset order.

    Raises:
        KeyError: If a label is missing from ``class_map``.
        IndexError: If a mapped label is not below ``num_classes``.
    """
    class_indices = [[] for _ in range(num_classes)]

    imgs = getattr(dataset, 'imgs', None)
    targets = getattr(dataset, 'targets', None)
    if imgs is not None:
        labels = (label for _, label in imgs)
    elif targets is not None and getattr(dataset, 'target_transform', None) is None:
        labels = targets
    else:
        labels = (label for _, label in dataset)

    for idx, label in enumerate(labels):
        if torch.is_tensor(label):
            label = label.item()
        class_indices[class_map[label]].append(idx)
    return class_indices
157+
142158
def hyper_param_search_for_hard_label(self, image_tensor, image_path, hard_labels, mode='real'):
143159
lr_list = [0.001, 0.005, 0.01, 0.05, 0.1]
144160
best_acc = 0
@@ -177,7 +193,7 @@ def hyper_param_search_for_soft_label(self, image_tensor, image_path, soft_label
177193
model = build_model(
178194
model_name=self.model_name,
179195
num_classes=self.num_classes,
180-
im_size=self.im_size,
196+
im_size=self.im_size,
181197
pretrained=False,
182198
use_torchvision=self.stu_use_torchvision,
183199
device=self.device
@@ -196,11 +212,13 @@ def hyper_param_search_for_soft_label(self, image_tensor, image_path, soft_label
196212
return best_acc, best_lr
197213

198214
def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_labels, mode='real'):
199-
200-
if image_tensor is None:
201-
hard_label_dataset = datasets.ImageFolder(root=image_path, transform=self.custom_train_trans)
215+
if mode == 'real':
216+
hard_label_dataset = self.dst_train
202217
else:
203-
hard_label_dataset = TensorDataset(image_tensor, hard_labels)
218+
if image_tensor is None:
219+
hard_label_dataset = datasets.ImageFolder(root=image_path, transform=self.custom_train_trans)
220+
else:
221+
hard_label_dataset = TensorDataset(image_tensor, hard_labels)
204222
train_loader = DataLoader(hard_label_dataset, batch_size=self.real_batch_size if mode == 'real' else self.syn_batch_size,
205223
num_workers=self.num_workers, shuffle=True)
206224

@@ -216,15 +234,15 @@ def compute_hard_label_metrics(self, model, image_tensor, image_path, lr, hard_l
216234
loader=train_loader,
217235
loss_fn=loss_fn,
218236
optimizer=optimizer,
219-
aug_func=self.aug_func,
237+
aug_func=self.aug_func if self.use_aug_for_hard else None,
220238
lr_scheduler=lr_scheduler,
221239
tea_model=self.teacher_model,
222240
device=self.device
223241
)
224242
if epoch > 0.8 * self.num_epochs and (epoch + 1) % self.test_interval == 0:
225243
metric = validate(
226244
model=model,
227-
loader=self.test_loader,
245+
loader=self.test_loader_real,
228246
device=self.device
229247
)
230248
if metric['top1'] > best_acc1:
@@ -272,7 +290,7 @@ def compute_soft_label_metrics(self, model, image_tensor, image_path, lr, soft_l
272290
if epoch > 0.8 * self.num_epochs and (epoch + 1) % self.test_interval == 0:
273291
metric = validate(
274292
model=model,
275-
loader=self.test_loader,
293+
loader=self.test_loader_syn,
276294
device=self.device
277295
)
278296
if metric['top1'] > best_acc1:
@@ -326,10 +344,10 @@ def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, soft_
326344
)
327345
full_data_hard_label_acc = self.compute_hard_label_metrics(
328346
model=model,
329-
image_tensor=self.images_train,
347+
image_tensor=None,
330348
image_path=None,
331349
lr=self.default_lr,
332-
hard_labels=self.labels_train,
350+
hard_labels=None,
333351
mode='real'
334352
)
335353
del model
@@ -362,7 +380,7 @@ def compute_metrics(self, image_tensor: Tensor=None, image_path: str=None, soft_
362380
print(f"Syn data soft label acc: {syn_data_soft_label_acc:.2f}%")
363381

364382
print("Caculating random data soft label metrics...")
365-
random_images, _ = get_random_images(self.images_train, self.labels_train, self.class_indices_train, self.ipc)
383+
random_images, _ = get_random_images(self.dst_train, self.class_indices, self.ipc)
366384
if self.soft_label_mode == 'S':
367385
random_data_soft_labels = self.generate_soft_labels(random_images)
368386
else:

0 commit comments

Comments
 (0)